mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-10 17:14:36 +00:00
Added extra endpoints for aborting generation and for polled streaming
This commit is contained in:
parent
5bd9cef9fa
commit
43f7e40470
5 changed files with 63 additions and 25 deletions
41
koboldcpp.py
41
koboldcpp.py
|
@ -45,7 +45,8 @@ class generation_inputs(ctypes.Structure):
|
|||
("mirostat", ctypes.c_int),
|
||||
("mirostat_tau", ctypes.c_float),
|
||||
("mirostat_eta", ctypes.c_float),
|
||||
("stop_sequence", ctypes.c_char_p * stop_token_max)]
|
||||
("stop_sequence", ctypes.c_char_p * stop_token_max),
|
||||
("stream_sse", ctypes.c_bool)]
|
||||
|
||||
class generation_outputs(ctypes.Structure):
|
||||
_fields_ = [("status", ctypes.c_int),
|
||||
|
@ -139,6 +140,8 @@ def init_library():
|
|||
handle.new_token.argtypes = [ctypes.c_int]
|
||||
handle.get_stream_count.restype = ctypes.c_int
|
||||
handle.has_finished.restype = ctypes.c_bool
|
||||
handle.abort_generate.restype = ctypes.c_bool
|
||||
handle.get_pending_output.restype = ctypes.c_char_p
|
||||
|
||||
def load_model(model_filename):
|
||||
inputs = load_model_inputs()
|
||||
|
@ -167,7 +170,7 @@ def load_model(model_filename):
|
|||
ret = handle.load_model(inputs)
|
||||
return ret
|
||||
|
||||
def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[]):
|
||||
def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=120, top_a=0.0 ,top_p=0.85, typical_p=1.0, tfs=1.0 ,rep_pen=1.1,rep_pen_range=128,seed=-1,stop_sequence=[],stream_sse=False):
|
||||
inputs = generation_inputs()
|
||||
outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
|
||||
inputs.prompt = prompt.encode("UTF-8")
|
||||
|
@ -181,6 +184,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k=
|
|||
inputs.tfs = tfs
|
||||
inputs.rep_pen = rep_pen
|
||||
inputs.rep_pen_range = rep_pen_range
|
||||
inputs.stream_sse = stream_sse
|
||||
if args.usemirostat and args.usemirostat[0]>0:
|
||||
inputs.mirostat = int(args.usemirostat[0])
|
||||
inputs.mirostat_tau = float(args.usemirostat[1])
|
||||
|
@ -215,7 +219,7 @@ maxctx = 2048
|
|||
maxlen = 256
|
||||
modelbusy = False
|
||||
defaultport = 5001
|
||||
KcppVersion = "1.29"
|
||||
KcppVersion = "1.30"
|
||||
|
||||
class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
||||
sys_version = ""
|
||||
|
@ -229,7 +233,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
def __call__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
async def generate_text(self, newprompt, genparams, basic_api_flag):
|
||||
async def generate_text(self, newprompt, genparams, basic_api_flag, stream_flag):
|
||||
loop = asyncio.get_event_loop()
|
||||
executor = ThreadPoolExecutor()
|
||||
|
||||
|
@ -247,8 +251,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
rep_pen=genparams.get('rep_pen', 1.1),
|
||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||
seed=genparams.get('sampler_seed', -1),
|
||||
stop_sequence=genparams.get('stop_sequence', [])
|
||||
)
|
||||
stop_sequence=genparams.get('stop_sequence', []),
|
||||
stream_sse=stream_flag)
|
||||
|
||||
else:
|
||||
return generate(prompt=newprompt,
|
||||
max_context_length=genparams.get('max_context_length', maxctx),
|
||||
|
@ -262,8 +267,9 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
rep_pen=genparams.get('rep_pen', 1.1),
|
||||
rep_pen_range=genparams.get('rep_pen_range', 128),
|
||||
seed=genparams.get('sampler_seed', -1),
|
||||
stop_sequence=genparams.get('stop_sequence', [])
|
||||
)
|
||||
stop_sequence=genparams.get('stop_sequence', []),
|
||||
stream_sse=stream_flag)
|
||||
|
||||
|
||||
recvtxt = await loop.run_in_executor(executor, run_blocking)
|
||||
|
||||
|
@ -300,7 +306,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
|
||||
current_token += 1
|
||||
|
||||
tokenStr = ctypes.string_at(token).decode('utf-8')
|
||||
tokenStr = ctypes.string_at(token).decode("UTF-8","ignore")
|
||||
event_data = {"token": tokenStr}
|
||||
event_str = json.dumps(event_data)
|
||||
await self.send_sse_event("message", event_str)
|
||||
|
@ -319,7 +325,7 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
if stream_flag:
|
||||
tasks.append(self.handle_sse_stream())
|
||||
|
||||
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag))
|
||||
generate_task = asyncio.create_task(self.generate_text(newprompt, genparams, basic_api_flag, stream_flag))
|
||||
tasks.append(generate_task)
|
||||
|
||||
try:
|
||||
|
@ -395,6 +401,21 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
|
|||
kai_sse_stream_flag = False
|
||||
self.path = self.path.rstrip('/')
|
||||
|
||||
if self.path.endswith('/api/extra/abort'):
|
||||
ag = handle.abort_generate()
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"success": ("true" if ag else "false")}).encode())
|
||||
print("Generation Aborted")
|
||||
return
|
||||
|
||||
if self.path.endswith('/api/extra/generate/check'):
|
||||
pendtxt = handle.get_pending_output()
|
||||
pendtxtStr = ctypes.string_at(pendtxt).decode("UTF-8","ignore")
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
self.wfile.write(json.dumps({"results": [{"text": pendtxtStr}]}).encode())
|
||||
return
|
||||
|
||||
if modelbusy:
|
||||
self.send_response(503)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue