added whisper file upload mode

This commit is contained in:
Concedo 2024-06-02 12:04:56 +08:00
parent 7ef31e541c
commit 9e64f0b5af

View file

@ -9,9 +9,10 @@
# scenarios and everything Kobold and Kobold Lite have to offer. # scenarios and everything Kobold and Kobold Lite have to offer.
import ctypes import ctypes
import os, math import os, math, re
import argparse import argparse
import platform import platform
import base64
import json, sys, http.server, time, asyncio, socket, threading import json, sys, http.server, time, asyncio, socket, threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -844,6 +845,26 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
super().log_message(format, *args) super().log_message(format, *args)
pass pass
def extract_b64string_from_file_upload(self, body):
    """Pull the uploaded "file" field out of a multipart/form-data body.

    The multipart boundary is read from the request's Content-Type header,
    the raw body is split into parts on the "--boundary" delimiter, and the
    first part whose Content-Disposition header declares name="file" with a
    filename is taken as the upload.

    Args:
        body: Raw request body as bytes.

    Returns:
        A "data:audio/wav;base64,..." string holding the uploaded bytes, or
        None when there is no usable Content-Type/boundary, no matching file
        part is found, or any error occurs while parsing.
    """
    try:
        content_type = self.headers.get('content-type')
        if not content_type:
            return None

        # The boundary parameter may be quoted, may itself contain '=', and may
        # be followed by further parameters, so it is parsed explicitly rather
        # than taken as "whatever follows the first '='".
        boundary_match = re.search(r'boundary=(?:"([^"]+)"|([^;\s]+))', content_type, re.IGNORECASE)
        if not boundary_match:
            return None
        boundary = (boundary_match.group(1) or boundary_match.group(2)).encode()

        # Parts are delimited by "--" + boundary in the body.
        for fpart in body.split(b'--' + boundary):
            # Split headers from payload at the FIRST blank line only. Binary
            # audio may legitimately contain CR LF CR LF, and splitting on every
            # occurrence would silently truncate the uploaded data.
            header_blob, separator, payload = fpart.partition(b'\r\n\r\n')
            if not separator:
                continue

            # Only the header section is decoded and searched; the payload is
            # left as raw bytes so file contents cannot produce a false match.
            headers_text = header_blob.decode('utf-8', errors='ignore')
            found = re.search(r'Content-Disposition.*name="file";\s*filename="([^"]*)"', headers_text, re.IGNORECASE)
            if not found:
                continue

            utfprint(f"Detected uploaded file: {found.group(1)}")

            # Drop the single CRLF that precedes the next boundary delimiter.
            if payload.endswith(b'\r\n'):
                payload = payload[:-2]

            file_data_base64 = base64.b64encode(payload).decode('utf-8')
            return f"data:audio/wav;base64,{file_data_base64}"

        print("Uploaded file not found.")
        return None
    except Exception as e:
        # Best-effort fallback parser: the caller treats None as "no upload",
        # so any parsing failure is reported and swallowed here.
        print(f"File Upload Process Error: {e}")
        return None
async def generate_text(self, genparams, api_format, stream_flag): async def generate_text(self, genparams, api_format, stream_flag):
from datetime import datetime from datetime import datetime
global friendlymodelname, chatcompl_adapter, currfinishreason global friendlymodelname, chatcompl_adapter, currfinishreason
@ -1425,7 +1446,7 @@ Enter Prompt:<br>
if self.path.endswith('/sdapi/v1/txt2img') or self.path.endswith('/sdapi/v1/img2img'): if self.path.endswith('/sdapi/v1/txt2img') or self.path.endswith('/sdapi/v1/img2img'):
is_imggen = True is_imggen = True
if self.path.endswith('/api/extra/transcribe'): if self.path.endswith('/api/extra/transcribe') or self.path.endswith('/v1/audio/transcriptions'):
is_transcribe = True is_transcribe = True
if is_imggen or is_transcribe or api_format > 0: if is_imggen or is_transcribe or api_format > 0:
@ -1440,14 +1461,21 @@ Enter Prompt:<br>
try: try:
genparams = json.loads(body) genparams = json.loads(body)
except Exception as e: except Exception as e:
utfprint("Body Err: " + str(body)) genparams = None
self.send_response(500) if is_transcribe: #fallback handling of file uploads
self.end_headers(content_type='application/json') b64wav = self.extract_b64string_from_file_upload(body)
self.wfile.write(json.dumps({"detail": { if b64wav:
"msg": "Error parsing input.", genparams = {"audio_data":b64wav}
"type": "bad_input",
}}).encode()) if not genparams:
return utfprint("Body Err: " + str(body))
self.send_response(500)
self.end_headers(content_type='application/json')
self.wfile.write(json.dumps({"detail": {
"msg": "Error parsing input.",
"type": "bad_input",
}}).encode())
return
is_quiet = args.quiet is_quiet = args.quiet
if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1: if (args.debugmode != -1 and not is_quiet) or args.debugmode >= 1:
@ -3463,7 +3491,7 @@ def main(launch_args,start_server=True):
from datetime import datetime, timezone from datetime import datetime, timezone
start_server = False start_server = False
save_to_file = (args.benchmark!="stdout" and args.benchmark!="") save_to_file = (args.benchmark!="stdout" and args.benchmark!="")
benchmaxctx = (16384 if maxctx>16384 else maxctx) benchmaxctx = maxctx
benchlen = 100 benchlen = 100
benchmodel = sanitize_string(os.path.splitext(os.path.basename(modelname))[0]) benchmodel = sanitize_string(os.path.splitext(os.path.basename(modelname))[0])
if os.path.exists(args.benchmark) and os.path.getsize(args.benchmark) > 1000000: if os.path.exists(args.benchmark) and os.path.getsize(args.benchmark) > 1000000:
@ -3475,7 +3503,7 @@ def main(launch_args,start_server=True):
print(f"\nRunning benchmark (Not Saved)...") print(f"\nRunning benchmark (Not Saved)...")
benchprompt = "1111111111111111" benchprompt = "1111111111111111"
for i in range(0,12): #generate massive prompt for i in range(0,14): #generate massive prompt
benchprompt += benchprompt benchprompt += benchprompt
genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True) genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True)
result = genout['text'] result = genout['text']