mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
implement sampler order, expose sampler order and mirostat in api
This commit is contained in:
parent
d6b47e6a5b
commit
309534dcd0
3 changed files with 84 additions and 8 deletions
24
koboldcpp.py
24
koboldcpp.py
|
@ -9,6 +9,7 @@ import json, sys, http.server, time, asyncio, socket, threading
|
|||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
stop_token_max = 10
|
||||
sampler_order_max = 7
|
||||
|
||||
class load_model_inputs(ctypes.Structure):
|
||||
_fields_ = [("threads", ctypes.c_int),
|
||||
|
@ -47,6 +48,8 @@ class generation_inputs(ctypes.Structure):
|
|||
("mirostat", ctypes.c_int),
|
||||
("mirostat_tau", ctypes.c_float),
|
||||
("mirostat_eta", ctypes.c_float),
|
||||
("sampler_order", ctypes.c_int * sampler_order_max),
|
||||
("sampler_len", ctypes.c_int),
|
||||
("stop_sequence", ctypes.c_char_p * stop_token_max),
|
||||
("stream_sse", ctypes.c_bool)]
|
||||
|
||||
|
@ -186,7 +189,7 @@ def load_model(model_filename):
|
|||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
||||
def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[],stream_sse=False):
|
||||
def generate(prompt,max_length=20, max_context_length=512, temperature=0.8, top_k=120, top_a=0.0, top_p=0.85, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=None, seed=-1, stop_sequence=[], stream_sse=False):
|
||||
inputs = generation_inputs()
|
||||
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
||||
inputs.prompt = prompt.encode("UTF-8")
|
||||
|
@ -205,8 +208,19 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
|
|||
inputs.mirostat = int(args.usemirostat[0])
|
||||
inputs.mirostat_tau = float(args.usemirostat[1])
|
||||
inputs.mirostat_eta = float(args.usemirostat[2])
|
||||
elif mirostat in (1, 2):
|
||||
inputs.mirostat = mirostat
|
||||
inputs.mirostat_tau = mirostat_tau
|
||||
inputs.mirostat_eta = mirostat_eta
|
||||
else:
|
||||
inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
|
||||
if sampler_order and 0 < len(sampler_order) <= sampler_order_max:
|
||||
try:
|
||||
for i, sampler in enumerate(sampler_order):
|
||||
inputs.sampler_order[i] = sampler
|
||||
inputs.sampler_len = len(sampler_order)
|
||||
except TypeError as e:
|
||||
print("ERROR: sampler_order must be a list of integers: " + str(e))
|
||||
inputs.seed = seed
|
||||
for n in range(stop_token_max):
|
||||
if not stop_sequence or n >= len(stop_sequence):
|
||||
|
@ -272,6 +286,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
tfs=genparams.get('tfs', 1.0),
|
||||
rep_pen=genparams.get('rep_pen', 1.1),
|
||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||
mirostat=genparams.get('mirostat', 0),
|
||||
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||
sampler_order=genparams.get('sampler_order', None),
|
||||
seed=genparams.get('sampler_seed', -1),
|
||||
stop_sequence=genparams.get('stop_sequence', []),
|
||||
stream_sse=stream_flag)
|
||||
|
@ -288,6 +306,10 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
tfs=genparams.get('tfs', 1.0),
|
||||
rep_pen=genparams.get('rep_pen', 1.1),
|
||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||
mirostat=genparams.get('mirostat', 0),
|
||||
mirostat_tau=genparams.get('mirostat_tau', 5.0),
|
||||
mirostat_eta=genparams.get('mirostat_eta', 0.1),
|
||||
sampler_order=genparams.get('sampler_order', None),
|
||||
seed=genparams.get('sampler_seed', -1),
|
||||
stop_sequence=genparams.get('stop_sequence', []),
|
||||
stream_sse=stream_flag)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue