added whisper file upload mode

This commit is contained in:
Concedo 2024-06-02 12:04:56 +08:00
parent 7ef31e541c
commit 9e64f0b5af

View file

@ -9,9 +9,10 @@
# scenarios and everything Kobold and Kobold Lite have to offer.
import ctypes
import os, math
import os, math, re
import argparse
import platform
import base64
import json, sys, http.server, time, asyncio, socket, threading
from concurrent.futures import ThreadPoolExecutor
@ -844,6 +845,26 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
super().log_message(format, *args)
pass
def extract_b64string_from_file_upload(self, body):
try:
if 'content-type' in self.headers and self.headers['content-type']:
boundary = self.headers['content-type'].split("=")[1].encode()
if boundary:
fparts = body.split(boundary)
for fpart in fparts:
detected_upload_filename = re.findall(r'Content-Disposition.*name="file"; filename="(.*)"', fpart.decode('utf-8',errors='ignore'))
if detected_upload_filename and len(detected_upload_filename)>0:
utfprint(f"Detected uploaded file: {detected_upload_filename[0]}")
file_data = fpart.split(b'\r\n\r\n')[1].rsplit(b'\r\n', 1)[0]
file_data_base64 = base64.b64encode(file_data).decode('utf-8')
base64_string = f"data:audio/wav;base64,{file_data_base64}"
return base64_string
print("Uploaded file not found.")
return None
except Exception as e:
print(f"File Upload Process Error: {e}")
return None
async def generate_text(self, genparams, api_format, stream_flag):
from datetime import datetime
global friendlymodelname, chatcompl_adapter, currfinishreason
@ -1425,7 +1446,7 @@ Enter Prompt:<br>
if self.path.endswith('/sdapi/v1/txt2img') or self.path.endswith('/sdapi/v1/img2img'):
is_imggen = True
if self.path.endswith('/api/extra/transcribe'):
if self.path.endswith('/api/extra/transcribe') or self.path.endswith('/v1/audio/transcriptions'):
is_transcribe = True
if is_imggen or is_transcribe or api_format > 0:
@ -1440,6 +1461,13 @@ Enter Prompt:<br>
try:
genparams = json.loads(body)
except Exception as e:
genparams = None
if is_transcribe: #fallback handling of file uploads
b64wav = self.extract_b64string_from_file_upload(body)
if b64wav:
genparams = {"audio_data":b64wav}
if not genparams:
utfprint("Body Err: " + str(body))
self.send_response(500)
self.end_headers(content_type='application/json')
@ -3463,7 +3491,7 @@ def main(launch_args,start_server=True):
from datetime import datetime, timezone
start_server = False
save_to_file = (args.benchmark!="stdout" and args.benchmark!="")
benchmaxctx = (16384 if maxctx>16384 else maxctx)
benchmaxctx = maxctx
benchlen = 100
benchmodel = sanitize_string(os.path.splitext(os.path.basename(modelname))[0])
if os.path.exists(args.benchmark) and os.path.getsize(args.benchmark) > 1000000:
@ -3475,7 +3503,7 @@ def main(launch_args,start_server=True):
print(f"\nRunning benchmark (Not Saved)...")
benchprompt = "1111111111111111"
for i in range(0,12): #generate massive prompt
for i in range(0,14): #generate massive prompt
benchprompt += benchprompt
genout = generate(benchprompt,memory="",images=[],max_length=benchlen,max_context_length=benchmaxctx,temperature=0.1,top_k=1,rep_pen=1,use_default_badwordsids=True)
result = genout['text']