initial whisper integration

This commit is contained in:
Concedo 2024-05-29 23:13:11 +08:00
parent 4ed9ba7352
commit f24aef8792
10 changed files with 16204 additions and 16 deletions

View file

@ -134,6 +134,23 @@ class sd_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("data", ctypes.c_char_p)]
class whisper_load_model_inputs(ctypes.Structure):
_fields_ = [("model_filename", ctypes.c_char_p),
("executable_path", ctypes.c_char_p),
("clblast_info", ctypes.c_int),
("cublas_info", ctypes.c_int),
("vulkan_info", ctypes.c_char_p),
("debugmode", ctypes.c_int)]
class whisper_generation_inputs(ctypes.Structure):
_fields_ = [("prompt", ctypes.c_char_p),
("audio_data", ctypes.c_char_p),
("quiet", ctypes.c_bool)]
class whisper_generation_outputs(ctypes.Structure):
_fields_ = [("status", ctypes.c_int),
("data", ctypes.c_char_p)]
handle = None
def getdirpath():
@ -304,6 +321,10 @@ def init_library():
handle.sd_load_model.restype = ctypes.c_bool
handle.sd_generate.argtypes = [sd_generation_inputs]
handle.sd_generate.restype = sd_generation_outputs
handle.whisper_load_model.argtypes = [whisper_load_model_inputs]
handle.whisper_load_model.restype = ctypes.c_bool
handle.whisper_generate.argtypes = [whisper_generation_inputs]
handle.whisper_generate.restype = whisper_generation_outputs
def set_backend_props(inputs):
clblastids = 0
@ -612,6 +633,32 @@ def sd_generate(genparams):
outstr = ret.data.decode("UTF-8","ignore")
return outstr
def whisper_load_model(model_filename):
global args
inputs = whisper_load_model_inputs()
inputs.debugmode = args.debugmode
inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
inputs.model_filename = model_filename.encode("UTF-8")
inputs = set_backend_props(inputs)
ret = handle.whisper_load_model(inputs)
return ret
def whisper_generate(genparams):
global args
is_quiet = True if args.quiet else False
prompt = genparams.get("prompt", "")
audio_data = genparams.get("audio_data", "")
inputs = whisper_generation_inputs()
inputs.prompt = prompt.encode("UTF-8")
inputs.audio_data = audio_data.encode("UTF-8")
inputs.quiet = is_quiet
ret = handle.whisper_generate(inputs)
outstr = ""
if ret.status==1:
outstr = ret.data.decode("UTF-8","ignore")
return outstr
def utfprint(str):
maxlen = 99999
strlength = len(str)
@ -1547,7 +1594,7 @@ def show_new_gui():
root.quit()
if args.model_param and args.model_param!="" and args.model_param.lower().endswith('.kcpps'):
loadconfigfile(args.model_param)
if not args.model_param and not args.sdmodel:
if not args.model_param and not args.sdmodel and not args.whispermodel:
global exitcounter
exitcounter = 999
print("\nNo ggml model or kcpps file was selected. Exiting.")
@ -2568,13 +2615,13 @@ def show_new_gui():
if nextstate==0:
exitcounter = 999
print("Exiting by user request.")
time.sleep(3)
time.sleep(1)
sys.exit(0)
else:
# processing vars
export_vars()
if not args.model_param and not args.sdmodel:
if not args.model_param and not args.sdmodel and not args.whispermodel:
exitcounter = 999
print("\nNo text or image model file was selected. Exiting.")
time.sleep(3)
@ -3050,7 +3097,7 @@ def main(launch_args,start_server=True):
if not args.model_param:
args.model_param = args.model
if not args.model_param and not args.sdmodel:
if not args.model_param and not args.sdmodel and not args.whispermodel:
#give them a chance to pick a file
print("For command line arguments, please refer to --help")
print("***")
@ -3280,6 +3327,28 @@ def main(launch_args,start_server=True):
time.sleep(3)
sys.exit(3)
#handle whisper model
if args.whispermodel and args.whispermodel!="":
whispermodel = args.whispermodel
if not whispermodel or not os.path.exists(whispermodel):
print(f"Cannot find whisper model file: {whispermodel}")
if args.ignoremissing:
print(f"Ignoring missing whisper model file...")
args.whispermodel = None
else:
exitcounter = 999
time.sleep(3)
sys.exit(2)
else:
whispermodel = os.path.abspath(whispermodel)
loadok = whisper_load_model(whispermodel)
print("Load Whisper Model OK: " + str(loadok))
if not loadok:
exitcounter = 999
print("Could not load whisper model: " + imgmodel)
time.sleep(3)
sys.exit(3)
#load embedded lite
try:
basepath = os.path.abspath(os.path.dirname(__file__))
@ -3537,6 +3606,9 @@ if __name__ == '__main__':
sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify a stable diffusion LORA safetensors model to be applied. Cannot be used with quant models.", default="")
sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the LORA model to be applied.", type=float, default=1.0)
whisperparsergroup = parser.add_argument_group('Whisper Transcription Commands')
whisperparsergroup.add_argument("--whispermodel", metavar=('[filename]'), help="Specify a Whisper bin model to enable Speech-To-Text transcription.", default="")
deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
deprecatedgroup.add_argument("--sdconfig", help=argparse.SUPPRESS, nargs='+')