diff --git a/Makefile b/Makefile
index 6209c29b4..c1247f071 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 default: koboldcpp_default koboldcpp_failsafe koboldcpp_openblas koboldcpp_noavx2 koboldcpp_clblast koboldcpp_clblast_noavx2 koboldcpp_cublas koboldcpp_hipblas koboldcpp_vulkan koboldcpp_vulkan_noavx2
-tools: quantize_gpt2 quantize_gptj quantize_gguf quantize_neox quantize_mpt quantize_clip sdmain gguf-split
+tools: quantize_gpt2 quantize_gptj quantize_gguf quantize_neox quantize_mpt quantize_clip whispermain sdmain gguf-split
 dev: koboldcpp_openblas
 dev2: koboldcpp_clblast
@@ -512,6 +512,10 @@ sdcpp_default.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffu
 sdcpp_cublas.o: otherarch/sdcpp/sdtype_adapter.cpp otherarch/sdcpp/stable-diffusion.h otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/util.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c
 	$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
 
+#whisper objects
+whispercpp_default.o: otherarch/whispercpp/whisper_adapter.cpp
+	$(CXX) $(CXXFLAGS) -c $< -o $@
+
 # idiotic "for easier compilation"
 GPTTYPE_ADAPTER = gpttype_adapter.cpp otherarch/llama_v2.cpp otherarch/llama_v3.cpp llama.cpp otherarch/utils.cpp otherarch/gptj_v1.cpp otherarch/gptj_v2.cpp otherarch/gptj_v3.cpp otherarch/gpt2_v1.cpp otherarch/gpt2_v2.cpp otherarch/gpt2_v3.cpp otherarch/rwkv_v2.cpp otherarch/rwkv_v3.cpp otherarch/neox_v2.cpp otherarch/neox_v3.cpp otherarch/mpt_v3.cpp ggml.h ggml-cuda.h llama.h otherarch/llama-util.h
 gpttype_adapter_failsafe.o: $(GPTTYPE_ADAPTER)
@@ -530,7 +534,7 @@ gpttype_adapter_vulkan_noavx2.o: $(GPTTYPE_ADAPTER)
 	$(CXX) $(CXXFLAGS) $(FAILSAFE_FLAGS) $(VULKAN_FLAGS) -c $< -o $@
 
 clean:
-	rm -vf *.o main sdmain quantize_gguf quantize_clip quantize_gpt2 quantize_gptj quantize_neox quantize_mpt quantize-stats perplexity embedding benchmark-matmult save-load-state gguf imatrix imatrix.exe gguf.exe main.exe quantize_clip.exe quantize_gguf.exe quantize_gptj.exe quantize_gpt2.exe quantize_neox.exe quantize_mpt.exe koboldcpp_default.dll koboldcpp_openblas.dll koboldcpp_failsafe.dll koboldcpp_noavx2.dll koboldcpp_clblast.dll koboldcpp_clblast_noavx2.dll koboldcpp_cublas.dll koboldcpp_hipblas.dll koboldcpp_vulkan.dll koboldcpp_vulkan_noavx2.dll koboldcpp_default.so koboldcpp_openblas.so koboldcpp_failsafe.so koboldcpp_noavx2.so koboldcpp_clblast.so koboldcpp_clblast_noavx2.so koboldcpp_cublas.so koboldcpp_hipblas.so koboldcpp_vulkan.so koboldcpp_vulkan_noavx2.so
+	rm -vf *.o main sdmain whispermain quantize_gguf quantize_clip quantize_gpt2 quantize_gptj quantize_neox quantize_mpt quantize-stats perplexity embedding benchmark-matmult save-load-state gguf imatrix imatrix.exe gguf.exe main.exe quantize_clip.exe quantize_gguf.exe quantize_gptj.exe quantize_gpt2.exe quantize_neox.exe quantize_mpt.exe koboldcpp_default.dll koboldcpp_openblas.dll koboldcpp_failsafe.dll koboldcpp_noavx2.dll koboldcpp_clblast.dll koboldcpp_clblast_noavx2.dll koboldcpp_cublas.dll koboldcpp_hipblas.dll koboldcpp_vulkan.dll koboldcpp_vulkan_noavx2.dll koboldcpp_default.so koboldcpp_openblas.so koboldcpp_failsafe.so koboldcpp_noavx2.so koboldcpp_clblast.so koboldcpp_clblast_noavx2.so koboldcpp_cublas.so koboldcpp_hipblas.so koboldcpp_vulkan.so koboldcpp_vulkan_noavx2.so
 	rm -vrf ggml-cuda/*.o
 
 # useful tools
@@ -539,6 +543,8 @@ main: examples/main/main.cpp build-info.h ggml.o llama.o console.o llavaclip_def
 	@echo '==== Run ./main -h for help. ===='
 sdmain: otherarch/sdcpp/util.cpp otherarch/sdcpp/main.cpp otherarch/sdcpp/stable-diffusion.cpp otherarch/sdcpp/upscaler.cpp otherarch/sdcpp/model.cpp otherarch/sdcpp/thirdparty/zip.c build-info.h ggml.o llama.o console.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
+whispermain: otherarch/whispercpp/main.cpp otherarch/whispercpp/whisper.cpp build-info.h ggml.o llama.o console.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
+	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 imatrix: examples/imatrix/imatrix.cpp build-info.h ggml.o llama.o console.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
 	$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
 gguf: examples/gguf/gguf.cpp build-info.h ggml.o llama.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
@@ -548,11 +554,11 @@ gguf-split: examples/gguf-split/gguf-split.cpp ggml.o llama.o build-info.h llava
 
 #generated libraries
-koboldcpp_default: ggml.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter.o sdcpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
+koboldcpp_default: ggml.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
 	$(DEFAULT_BUILD)
 
 ifdef OPENBLAS_BUILD
-koboldcpp_openblas: ggml_v4_openblas.o ggml_v3_openblas.o ggml_v2_openblas.o ggml_v1.o expose.o gpttype_adapter.o sdcpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
+koboldcpp_openblas: ggml_v4_openblas.o ggml_v3_openblas.o ggml_v2_openblas.o ggml_v1.o expose.o gpttype_adapter.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
 	$(OPENBLAS_BUILD)
 else
 koboldcpp_openblas:
@@ -560,7 +566,7 @@ koboldcpp_openblas:
 endif
 
 ifdef FAILSAFE_BUILD
-koboldcpp_failsafe: ggml_v4_failsafe.o ggml_v3_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FAILSAFE) $(OBJS)
+koboldcpp_failsafe: ggml_v4_failsafe.o ggml_v3_failsafe.o ggml_v2_failsafe.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FAILSAFE) $(OBJS)
 	$(FAILSAFE_BUILD)
 else
 koboldcpp_failsafe:
@@ -568,7 +574,7 @@ koboldcpp_failsafe:
 endif
 
 ifdef NOAVX2_BUILD
-koboldcpp_noavx2: ggml_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_SIMPLE) $(OBJS)
+koboldcpp_noavx2: ggml_v4_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_failsafe.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_SIMPLE) $(OBJS)
 	$(NOAVX2_BUILD)
 else
 koboldcpp_noavx2:
@@ -576,10 +582,10 @@ koboldcpp_noavx2:
 endif
 
 ifdef CLBLAST_BUILD
-koboldcpp_clblast: ggml_v4_clblast.o ggml_v3_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
+koboldcpp_clblast: ggml_v4_clblast.o ggml_v3_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_FULL) $(OBJS)
 	$(CLBLAST_BUILD)
 ifdef NOAVX2_BUILD
-koboldcpp_clblast_noavx2: ggml_v4_clblast_noavx2.o ggml_v3_clblast_noavx2.o ggml_v2_clblast_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_clblast_noavx2.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_SIMPLE) $(OBJS)
+koboldcpp_clblast_noavx2: ggml_v4_clblast_noavx2.o ggml_v3_clblast_noavx2.o ggml_v2_clblast_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_clblast_noavx2.o ggml-opencl.o ggml_v3-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_default.o $(OBJS_SIMPLE) $(OBJS)
 	$(CLBLAST_BUILD)
 else
 koboldcpp_clblast_noavx2:
@@ -593,7 +599,7 @@ koboldcpp_clblast_noavx2:
 endif
 
 ifdef CUBLAS_BUILD
-koboldcpp_cublas: ggml_v4_cublas.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o llavaclip_cublas.o llava.o ggml-backend_cublas.o $(CUBLAS_OBJS) $(OBJS_FULL) $(OBJS)
+koboldcpp_cublas: ggml_v4_cublas.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_default.o llavaclip_cublas.o llava.o ggml-backend_cublas.o $(CUBLAS_OBJS) $(OBJS_FULL) $(OBJS)
 	$(CUBLAS_BUILD)
 else
 koboldcpp_cublas:
@@ -601,7 +607,7 @@ koboldcpp_cublas:
 endif
 
 ifdef HIPBLAS_BUILD
-koboldcpp_hipblas: ggml_v4_cublas.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o llavaclip_cublas.o llava.o ggml-backend_cublas.o $(HIP_OBJS) $(OBJS_FULL) $(OBJS)
+koboldcpp_hipblas: ggml_v4_cublas.o ggml_v3_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o gpttype_adapter_cublas.o sdcpp_cublas.o whispercpp_default.o llavaclip_cublas.o llava.o ggml-backend_cublas.o $(HIP_OBJS) $(OBJS_FULL) $(OBJS)
 	$(HIPBLAS_BUILD)
 else
 koboldcpp_hipblas:
@@ -609,10 +615,10 @@ koboldcpp_hipblas:
 endif
 
 ifdef VULKAN_BUILD
-koboldcpp_vulkan: ggml_v4_vulkan.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter_vulkan.o ggml-vulkan.o sdcpp_default.o llavaclip_default.o llava.o ggml-backend_vulkan.o $(OBJS_FULL) $(OBJS)
+koboldcpp_vulkan: ggml_v4_vulkan.o ggml_v3.o ggml_v2.o ggml_v1.o expose.o gpttype_adapter_vulkan.o ggml-vulkan.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_vulkan.o $(OBJS_FULL) $(OBJS)
 	$(VULKAN_BUILD)
 ifdef NOAVX2_BUILD
-koboldcpp_vulkan_noavx2: ggml_v4_vulkan_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_vulkan_noavx2.o ggml-vulkan.o sdcpp_default.o llavaclip_default.o llava.o ggml-backend_vulkan.o $(OBJS_SIMPLE) $(OBJS)
+koboldcpp_vulkan_noavx2: ggml_v4_vulkan_noavx2.o ggml_v3_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o gpttype_adapter_vulkan_noavx2.o ggml-vulkan.o sdcpp_default.o whispercpp_default.o llavaclip_default.o llava.o ggml-backend_vulkan.o $(OBJS_SIMPLE) $(OBJS)
 	$(VULKAN_BUILD)
 else
 koboldcpp_vulkan_noavx2:
diff --git a/expose.cpp b/expose.cpp
index b6327afd7..ec2506c49 100644
--- a/expose.cpp
+++ b/expose.cpp
@@ -221,6 +221,15 @@ extern "C"
         return sdtype_generate(inputs);
     }
 
+    bool whisper_load_model(const whisper_load_model_inputs inputs)
+    {
+        return whispertype_load_model(inputs);
+    }
+    whisper_generation_outputs whisper_generate(const whisper_generation_inputs inputs)
+    {
+        return whispertype_generate(inputs);
+    }
+
     const char * new_token(int idx) {
         if (generated_tokens.size() <= idx || idx < 0) return nullptr;
diff --git a/expose.h b/expose.h
index 592b899bd..597c72e77 100644
--- a/expose.h
+++ b/expose.h
@@ -142,6 +142,26 @@ struct sd_generation_outputs
 {
    int status = -1;
    const char * data = "";
 };
+struct whisper_load_model_inputs
+{
+    const char * model_filename;
+    const char * executable_path;
+    const int clblast_info = 0;
+    const int cublas_info = 0;
+    const char * vulkan_info;
+    const int debugmode = 0;
+};
+struct whisper_generation_inputs
+{
+    const char * prompt;
+    const char * audio_data;
+    const bool quiet = false;
+};
+struct whisper_generation_outputs
+{
+    int status = -1;
+    const char * text = "";
+};
 
 extern std::string executable_path;
 extern std::string lora_filename;
diff --git a/koboldcpp.py b/koboldcpp.py
index e8b93e7e5..7af163a3f 100644
--- a/koboldcpp.py
+++ b/koboldcpp.py
@@ -134,6 +134,23 @@ class sd_generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
                 ("data", ctypes.c_char_p)]
 
+class whisper_load_model_inputs(ctypes.Structure):
+    _fields_ = [("model_filename", ctypes.c_char_p),
+                ("executable_path", ctypes.c_char_p),
+                ("clblast_info", ctypes.c_int),
+                ("cublas_info", ctypes.c_int),
+                ("vulkan_info", ctypes.c_char_p),
+                ("debugmode", ctypes.c_int)]
+
+class whisper_generation_inputs(ctypes.Structure):
+    _fields_ = [("prompt", ctypes.c_char_p),
+                ("audio_data", ctypes.c_char_p),
+                ("quiet", ctypes.c_bool)]
+
+class whisper_generation_outputs(ctypes.Structure):
+    _fields_ = [("status", ctypes.c_int),
+                ("data", ctypes.c_char_p)]
+
 handle = None
 
 def getdirpath():
@@ -304,6 +321,10 @@ def init_library():
     handle.sd_load_model.restype = ctypes.c_bool
     handle.sd_generate.argtypes = [sd_generation_inputs]
     handle.sd_generate.restype = sd_generation_outputs
+    handle.whisper_load_model.argtypes = [whisper_load_model_inputs]
+    handle.whisper_load_model.restype = ctypes.c_bool
+    handle.whisper_generate.argtypes = [whisper_generation_inputs]
+    handle.whisper_generate.restype = whisper_generation_outputs
 
 def set_backend_props(inputs):
     clblastids = 0
@@ -612,6 +633,32 @@ def sd_generate(genparams):
         outstr = ret.data.decode("UTF-8","ignore")
     return outstr
 
+
+def whisper_load_model(model_filename):
+    global args
+    inputs = whisper_load_model_inputs()
+    inputs.debugmode = args.debugmode
+    inputs.executable_path = (getdirpath()+"/").encode("UTF-8")
+    inputs.model_filename = model_filename.encode("UTF-8")
+    inputs = set_backend_props(inputs)
+    ret = handle.whisper_load_model(inputs)
+    return ret
+
+def whisper_generate(genparams):
+    global args
+    is_quiet = True if args.quiet else False
+    prompt = genparams.get("prompt", "")
+    audio_data = genparams.get("audio_data", "")
+    inputs = whisper_generation_inputs()
+    inputs.prompt = prompt.encode("UTF-8")
+    inputs.audio_data = audio_data.encode("UTF-8")
+    inputs.quiet = is_quiet
+    ret = handle.whisper_generate(inputs)
+    outstr = ""
+    if ret.status==1:
+        outstr = ret.data.decode("UTF-8","ignore")
+    return outstr
+
 def utfprint(str):
     maxlen = 99999
     strlength = len(str)
@@ -1547,7 +1594,7 @@ def show_new_gui():
         root.quit()
         if args.model_param and args.model_param!="" and args.model_param.lower().endswith('.kcpps'):
             loadconfigfile(args.model_param)
-        if not args.model_param and not args.sdmodel:
+        if not args.model_param and not args.sdmodel and not args.whispermodel:
             global exitcounter
             exitcounter = 999
             print("\nNo ggml model or kcpps file was selected. Exiting.")
@@ -2568,13 +2615,13 @@ def show_new_gui():
     if nextstate==0:
         exitcounter = 999
         print("Exiting by user request.")
-        time.sleep(3)
+        time.sleep(1)
         sys.exit(0)
     else:
         # processing vars
         export_vars()
 
-        if not args.model_param and not args.sdmodel:
+        if not args.model_param and not args.sdmodel and not args.whispermodel:
             exitcounter = 999
             print("\nNo text or image model file was selected. Exiting.")
             time.sleep(3)
@@ -3050,7 +3097,7 @@ def main(launch_args,start_server=True):
     if not args.model_param:
         args.model_param = args.model
 
-    if not args.model_param and not args.sdmodel:
+    if not args.model_param and not args.sdmodel and not args.whispermodel:
         #give them a chance to pick a file
         print("For command line arguments, please refer to --help")
         print("***")
@@ -3280,6 +3327,28 @@ def main(launch_args,start_server=True):
             time.sleep(3)
             sys.exit(3)
 
+    #handle whisper model
+    if args.whispermodel and args.whispermodel!="":
+        whispermodel = args.whispermodel
+        if not whispermodel or not os.path.exists(whispermodel):
+            print(f"Cannot find whisper model file: {whispermodel}")
+            if args.ignoremissing:
+                print(f"Ignoring missing whisper model file...")
+                args.whispermodel = None
+            else:
+                exitcounter = 999
+                time.sleep(3)
+                sys.exit(2)
+        else:
+            whispermodel = os.path.abspath(whispermodel)
+            loadok = whisper_load_model(whispermodel)
+            print("Load Whisper Model OK: " + str(loadok))
+            if not loadok:
+                exitcounter = 999
+                print("Could not load whisper model: " + whispermodel)
+                time.sleep(3)
+                sys.exit(3)
+
     #load embedded lite
     try:
         basepath = os.path.abspath(os.path.dirname(__file__))
@@ -3537,6 +3606,9 @@ if __name__ == '__main__':
     sdparsergrouplora.add_argument("--sdlora", metavar=('[filename]'), help="Specify a stable diffusion LORA safetensors model to be applied. Cannot be used with quant models.", default="")
     sdparsergroup.add_argument("--sdloramult", metavar=('[amount]'), help="Multiplier for the LORA model to be applied.", type=float, default=1.0)
 
+    whisperparsergroup = parser.add_argument_group('Whisper Transcription Commands')
+    whisperparsergroup.add_argument("--whispermodel", metavar=('[filename]'), help="Specify a Whisper bin model to enable Speech-To-Text transcription.", default="")
+
     deprecatedgroup = parser.add_argument_group('Deprecated Commands, DO NOT USE!')
     deprecatedgroup.add_argument("--hordeconfig", help=argparse.SUPPRESS, nargs='+')
     deprecatedgroup.add_argument("--sdconfig", help=argparse.SUPPRESS, nargs='+')
diff --git a/model_adapter.h b/model_adapter.h
index 8b0195664..7629a3f30 100644
--- a/model_adapter.h
+++ b/model_adapter.h
@@ -82,6 +82,9 @@ std::vector<int> gpttype_get_token_arr(const std::string & input);
 bool sdtype_load_model(const sd_load_model_inputs inputs);
 sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs);
 
+bool whispertype_load_model(const whisper_load_model_inputs inputs);
+whisper_generation_outputs whispertype_generate(const whisper_generation_inputs inputs);
+
 void timer_start();
 double timer_check();
 void print_tok_vec(std::vector<int> &embd);
diff --git a/otherarch/whispercpp/dr_wav.h b/otherarch/whispercpp/dr_wav.h
new file mode 100644
index 000000000..612149bfe
--- /dev/null
+++ b/otherarch/whispercpp/dr_wav.h
@@ -0,0 +1,6434 @@
+/*
+WAV audio loader and writer. Choice of public domain or MIT-0. See license statements at the end of this file.
+
+dr_wav - v0.12.16 - 2020-12-02
+
+David Reid - mackron@gmail.com
+
+GitHub: https://github.com/mackron/dr_libs
+*/
+
+/*
+RELEASE NOTES - VERSION 0.12
+============================
+Version 0.12 includes breaking changes to custom chunk handling.
+
+
+Changes to Chunk Callback
+-------------------------
+dr_wav supports the ability to fire a callback when a chunk is encountered (except for WAVE and FMT chunks). The callback has been updated to include both the
+container (RIFF or Wave64) and the FMT chunk which contains information about the format of the data in the wave file.
+
+Previously, there was no direct way to determine the container, and therefore no way to discriminate against the different IDs in the chunk header (RIFF and
+Wave64 containers encode chunk ID's differently). The `container` parameter can be used to know which ID to use.
+
+Sometimes it can be useful to know the data format at the time the chunk callback is fired. A pointer to a `drwav_fmt` object is now passed into the chunk
+callback which will give you information about the data format. To determine the sample format, use `drwav_fmt_get_format()`. This will return one of the
+`DR_WAVE_FORMAT_*` tokens.
+*/
+
+/*
+Introduction
+============
+This is a single file library. To use it, do something like the following in one .c file.
+
+    ```c
+    #define DR_WAV_IMPLEMENTATION
+    #include "dr_wav.h"
+    ```
+
+You can then #include this file in other parts of the program as you would with any other header file. Do something like the following to read audio data:
+
+    ```c
+    drwav wav;
+    if (!drwav_init_file(&wav, "my_song.wav", NULL)) {
+        // Error opening WAV file.
+    }
+
+    drwav_int32* pDecodedInterleavedPCMFrames = malloc(wav.totalPCMFrameCount * wav.channels * sizeof(drwav_int32));
+    size_t numberOfSamplesActuallyDecoded = drwav_read_pcm_frames_s32(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames);
+
+    ...
+
+    drwav_uninit(&wav);
+    ```
+
+If you just want to quickly open and read the audio data in a single operation you can do something like this:
+
+    ```c
+    unsigned int channels;
+    unsigned int sampleRate;
+    drwav_uint64 totalPCMFrameCount;
+    float* pSampleData = drwav_open_file_and_read_pcm_frames_f32("my_song.wav", &channels, &sampleRate, &totalPCMFrameCount, NULL);
+    if (pSampleData == NULL) {
+        // Error opening and reading WAV file.
+    }
+
+    ...
+
+    drwav_free(pSampleData);
+    ```
+
+The examples above use versions of the API that convert the audio data to a consistent format (32-bit signed PCM, in this case), but you can still output the
+audio data in its internal format (see notes below for supported formats):
+
+    ```c
+    size_t framesRead = drwav_read_pcm_frames(&wav, wav.totalPCMFrameCount, pDecodedInterleavedPCMFrames);
+    ```
+
+You can also read the raw bytes of audio data, which could be useful if dr_wav does not have native support for a particular data format:
+
+    ```c
+    size_t bytesRead = drwav_read_raw(&wav, bytesToRead, pRawDataBuffer);
+    ```
+
+dr_wav can also be used to output WAV files. This does not currently support compressed formats. To use this, look at `drwav_init_write()`,
+`drwav_init_file_write()`, etc. Use `drwav_write_pcm_frames()` to write samples, or `drwav_write_raw()` to write raw data in the "data" chunk.
+
+    ```c
+    drwav_data_format format;
+    format.container = drwav_container_riff;   // <-- drwav_container_riff = normal WAV files, drwav_container_w64 = Sony Wave64.
+    format.format = DR_WAVE_FORMAT_PCM;        // <-- Any of the DR_WAVE_FORMAT_* codes.
+    format.channels = 2;
+    format.sampleRate = 44100;
+    format.bitsPerSample = 16;
+    drwav_init_file_write(&wav, "data/recording.wav", &format, NULL);
+
+    ...
+
+    drwav_uint64 framesWritten = drwav_write_pcm_frames(pWav, frameCount, pSamples);
+    ```
+
+dr_wav has seamless support for the Sony Wave64 format. The decoder will automatically detect it and it should Just Work without any manual intervention.
+
+
+Build Options
+=============
+#define these options before including this file.
+
+#define DR_WAV_NO_CONVERSION_API
+  Disables conversion APIs such as `drwav_read_pcm_frames_f32()` and `drwav_s16_to_f32()`.
+
+#define DR_WAV_NO_STDIO
+  Disables APIs that initialize a decoder from a file such as `drwav_init_file()`, `drwav_init_file_write()`, etc.
+
+
+
+Notes
+=====
+- Samples are always interleaved.
+- The default read function does not do any data conversion. Use `drwav_read_pcm_frames_f32()`, `drwav_read_pcm_frames_s32()` and `drwav_read_pcm_frames_s16()`
+  to read and convert audio data to 32-bit floating point, signed 32-bit integer and signed 16-bit integer samples respectively. Tested and supported internal
+  formats include the following:
+  - Unsigned 8-bit PCM
+  - Signed 12-bit PCM
+  - Signed 16-bit PCM
+  - Signed 24-bit PCM
+  - Signed 32-bit PCM
+  - IEEE 32-bit floating point
+  - IEEE 64-bit floating point
+  - A-law and u-law
+  - Microsoft ADPCM
+  - IMA ADPCM (DVI, format code 0x11)
+- dr_wav will try to read the WAV file as best it can, even if it's not strictly conformant to the WAV format.
+*/
+
+#ifndef dr_wav_h
+#define dr_wav_h
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define DRWAV_STRINGIFY(x) #x
+#define DRWAV_XSTRINGIFY(x) DRWAV_STRINGIFY(x)
+
+#define DRWAV_VERSION_MAJOR 0
+#define DRWAV_VERSION_MINOR 12
+#define DRWAV_VERSION_REVISION 16
+#define DRWAV_VERSION_STRING DRWAV_XSTRINGIFY(DRWAV_VERSION_MAJOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_MINOR) "." DRWAV_XSTRINGIFY(DRWAV_VERSION_REVISION)
+
+#include <stddef.h> /* For size_t. */
+
+/* Sized types.
*/ +typedef signed char drwav_int8; +typedef unsigned char drwav_uint8; +typedef signed short drwav_int16; +typedef unsigned short drwav_uint16; +typedef signed int drwav_int32; +typedef unsigned int drwav_uint32; +#if defined(_MSC_VER) + typedef signed __int64 drwav_int64; + typedef unsigned __int64 drwav_uint64; +#else + #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wlong-long" + #if defined(__clang__) + #pragma GCC diagnostic ignored "-Wc++11-long-long" + #endif + #endif + typedef signed long long drwav_int64; + typedef unsigned long long drwav_uint64; + #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) + #pragma GCC diagnostic pop + #endif +#endif +#if defined(__LP64__) || defined(_WIN64) || (defined(__x86_64__) && !defined(__ILP32__)) || defined(_M_X64) || defined(__ia64) || defined (_M_IA64) || defined(__aarch64__) || defined(__powerpc64__) + typedef drwav_uint64 drwav_uintptr; +#else + typedef drwav_uint32 drwav_uintptr; +#endif +typedef drwav_uint8 drwav_bool8; +typedef drwav_uint32 drwav_bool32; +#define DRWAV_TRUE 1 +#define DRWAV_FALSE 0 + +#if !defined(DRWAV_API) + #if defined(DRWAV_DLL) + #if defined(_WIN32) + #define DRWAV_DLL_IMPORT __declspec(dllimport) + #define DRWAV_DLL_EXPORT __declspec(dllexport) + #define DRWAV_DLL_PRIVATE static + #else + #if defined(__GNUC__) && __GNUC__ >= 4 + #define DRWAV_DLL_IMPORT __attribute__((visibility("default"))) + #define DRWAV_DLL_EXPORT __attribute__((visibility("default"))) + #define DRWAV_DLL_PRIVATE __attribute__((visibility("hidden"))) + #else + #define DRWAV_DLL_IMPORT + #define DRWAV_DLL_EXPORT + #define DRWAV_DLL_PRIVATE static + #endif + #endif + + #if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION) + #define DRWAV_API DRWAV_DLL_EXPORT + #else + #define DRWAV_API DRWAV_DLL_IMPORT + #endif + #define DRWAV_PRIVATE DRWAV_DLL_PRIVATE + #else + #define DRWAV_API extern + #define DRWAV_PRIVATE static + #endif +#endif + +typedef drwav_int32 drwav_result; +#define DRWAV_SUCCESS 0 +#define DRWAV_ERROR -1 /* A generic error. 
*/ +#define DRWAV_INVALID_ARGS -2 +#define DRWAV_INVALID_OPERATION -3 +#define DRWAV_OUT_OF_MEMORY -4 +#define DRWAV_OUT_OF_RANGE -5 +#define DRWAV_ACCESS_DENIED -6 +#define DRWAV_DOES_NOT_EXIST -7 +#define DRWAV_ALREADY_EXISTS -8 +#define DRWAV_TOO_MANY_OPEN_FILES -9 +#define DRWAV_INVALID_FILE -10 +#define DRWAV_TOO_BIG -11 +#define DRWAV_PATH_TOO_LONG -12 +#define DRWAV_NAME_TOO_LONG -13 +#define DRWAV_NOT_DIRECTORY -14 +#define DRWAV_IS_DIRECTORY -15 +#define DRWAV_DIRECTORY_NOT_EMPTY -16 +#define DRWAV_END_OF_FILE -17 +#define DRWAV_NO_SPACE -18 +#define DRWAV_BUSY -19 +#define DRWAV_IO_ERROR -20 +#define DRWAV_INTERRUPT -21 +#define DRWAV_UNAVAILABLE -22 +#define DRWAV_ALREADY_IN_USE -23 +#define DRWAV_BAD_ADDRESS -24 +#define DRWAV_BAD_SEEK -25 +#define DRWAV_BAD_PIPE -26 +#define DRWAV_DEADLOCK -27 +#define DRWAV_TOO_MANY_LINKS -28 +#define DRWAV_NOT_IMPLEMENTED -29 +#define DRWAV_NO_MESSAGE -30 +#define DRWAV_BAD_MESSAGE -31 +#define DRWAV_NO_DATA_AVAILABLE -32 +#define DRWAV_INVALID_DATA -33 +#define DRWAV_TIMEOUT -34 +#define DRWAV_NO_NETWORK -35 +#define DRWAV_NOT_UNIQUE -36 +#define DRWAV_NOT_SOCKET -37 +#define DRWAV_NO_ADDRESS -38 +#define DRWAV_BAD_PROTOCOL -39 +#define DRWAV_PROTOCOL_UNAVAILABLE -40 +#define DRWAV_PROTOCOL_NOT_SUPPORTED -41 +#define DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED -42 +#define DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED -43 +#define DRWAV_SOCKET_NOT_SUPPORTED -44 +#define DRWAV_CONNECTION_RESET -45 +#define DRWAV_ALREADY_CONNECTED -46 +#define DRWAV_NOT_CONNECTED -47 +#define DRWAV_CONNECTION_REFUSED -48 +#define DRWAV_NO_HOST -49 +#define DRWAV_IN_PROGRESS -50 +#define DRWAV_CANCELLED -51 +#define DRWAV_MEMORY_ALREADY_MAPPED -52 +#define DRWAV_AT_END -53 + +/* Common data formats. */ +#define DR_WAVE_FORMAT_PCM 0x1 +#define DR_WAVE_FORMAT_ADPCM 0x2 +#define DR_WAVE_FORMAT_IEEE_FLOAT 0x3 +#define DR_WAVE_FORMAT_ALAW 0x6 +#define DR_WAVE_FORMAT_MULAW 0x7 +#define DR_WAVE_FORMAT_DVI_ADPCM 0x11 +#define DR_WAVE_FORMAT_EXTENSIBLE 0xFFFE + +/* Constants. */ +#ifndef DRWAV_MAX_SMPL_LOOPS +#define DRWAV_MAX_SMPL_LOOPS 1 +#endif + +/* Flags to pass into drwav_init_ex(), etc. */ +#define DRWAV_SEQUENTIAL 0x00000001 + +DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision); +DRWAV_API const char* drwav_version_string(void); + +typedef enum +{ + drwav_seek_origin_start, + drwav_seek_origin_current +} drwav_seek_origin; + +typedef enum +{ + drwav_container_riff, + drwav_container_w64, + drwav_container_rf64 +} drwav_container; + +typedef struct +{ + union + { + drwav_uint8 fourcc[4]; + drwav_uint8 guid[16]; + } id; + + /* The size in bytes of the chunk. */ + drwav_uint64 sizeInBytes; + + /* + RIFF = 2 byte alignment. + W64 = 8 byte alignment. + */ + unsigned int paddingSize; +} drwav_chunk_header; + +typedef struct +{ + /* + The format tag exactly as specified in the wave file's "fmt" chunk. This can be used by applications + that require support for data formats not natively supported by dr_wav. + */ + drwav_uint16 formatTag; + + /* The number of channels making up the audio data. When this is set to 1 it is mono, 2 is stereo, etc. */ + drwav_uint16 channels; + + /* The sample rate. Usually set to something like 44100. */ + drwav_uint32 sampleRate; + + /* Average bytes per second. You probably don't need this, but it's left here for informational purposes. */ + drwav_uint32 avgBytesPerSec; + + /* Block align. This is equal to the number of channels * bytes per sample. */ + drwav_uint16 blockAlign; + + /* Bits per sample. 
+*/
+    drwav_uint16 bitsPerSample;
+
+    /* The size of the extended data. Only used internally for validation, but left here for informational purposes. */
+    drwav_uint16 extendedSize;
+
+    /*
+    The number of valid bits per sample. When <formatTag> is equal to WAVE_FORMAT_EXTENSIBLE, <bitsPerSample>
+    is always rounded up to the nearest multiple of 8. This variable contains information about exactly how
+    many bits are valid per sample. Mainly used for informational purposes.
+    */
+    drwav_uint16 validBitsPerSample;
+
+    /* The channel mask. Not used at the moment. */
+    drwav_uint32 channelMask;
+
+    /* The sub-format, exactly as specified by the wave file. */
+    drwav_uint8 subFormat[16];
+} drwav_fmt;
+
+DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT);
+
+
+/*
+Callback for when data is read. Return value is the number of bytes actually read.
+
+pUserData   [in] The user data that was passed to drwav_init() and family.
+pBufferOut  [out] The output buffer.
+bytesToRead [in] The number of bytes to read.
+
+Returns the number of bytes actually read.
+
+A return value of less than bytesToRead indicates the end of the stream. Do _not_ return from this callback until
+either the entire bytesToRead is filled or you have reached the end of the stream.
+*/
+typedef size_t (* drwav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead);
+
+/*
+Callback for when data is written. Return value is the number of bytes actually written.
+
+pUserData    [in] The user data that was passed to drwav_init_write() and family.
+pData        [out] A pointer to the data to write.
+bytesToWrite [in] The number of bytes to write.
+
+Returns the number of bytes actually written.
+
+If the return value differs from bytesToWrite, it indicates an error.
+*/
+typedef size_t (* drwav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite);
+
+/*
+Callback for when data needs to be seeked.
+
+pUserData [in] The user data that was passed to drwav_init() and family.
+offset    [in] The number of bytes to move, relative to the origin. Will never be negative.
+origin    [in] The origin of the seek - the current position or the start of the stream.
+
+Returns whether or not the seek was successful.
+
+Whether or not it is relative to the beginning or current position is determined by the "origin" parameter which will be either drwav_seek_origin_start or
+drwav_seek_origin_current.
+*/
+typedef drwav_bool32 (* drwav_seek_proc)(void* pUserData, int offset, drwav_seek_origin origin);
+
+/*
+Callback for when drwav_init_ex() finds a chunk.
+
+pChunkUserData    [in] The user data that was passed to the pChunkUserData parameter of drwav_init_ex() and family.
+onRead            [in] A pointer to the function to call when reading.
+onSeek            [in] A pointer to the function to call when seeking.
+pReadSeekUserData [in] The user data that was passed to the pReadSeekUserData parameter of drwav_init_ex() and family.
+pChunkHeader      [in] A pointer to an object containing basic header information about the chunk. Use this to identify the chunk.
+container         [in] Whether or not the WAV file is a RIFF or Wave64 container. If you're unsure of the difference, assume RIFF.
+pFMT              [in] A pointer to the object containing the contents of the "fmt" chunk.
+
+Returns the number of bytes read + seeked.
+
+To read data from the chunk, call onRead(), passing in pReadSeekUserData as the first parameter. Do the same for seeking with onSeek(). The return value must
+be the total number of bytes you have read _plus_ seeked.
+ +Use the `container` argument to discriminate the fields in `pChunkHeader->id`. If the container is `drwav_container_riff` or `drwav_container_rf64` you should +use `id.fourcc`, otherwise you should use `id.guid`. + +The `pFMT` parameter can be used to determine the data format of the wave file. Use `drwav_fmt_get_format()` to get the sample format, which will be one of the +`DR_WAVE_FORMAT_*` identifiers. + +The read pointer will be sitting on the first byte after the chunk's header. You must not attempt to read beyond the boundary of the chunk. +*/ +typedef drwav_uint64 (* drwav_chunk_proc)(void* pChunkUserData, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_chunk_header* pChunkHeader, drwav_container container, const drwav_fmt* pFMT); + +typedef struct +{ + void* pUserData; + void* (* onMalloc)(size_t sz, void* pUserData); + void* (* onRealloc)(void* p, size_t sz, void* pUserData); + void (* onFree)(void* p, void* pUserData); +} drwav_allocation_callbacks; + +/* Structure for internal use. Only used for loaders opened with drwav_init_memory(). */ +typedef struct +{ + const drwav_uint8* data; + size_t dataSize; + size_t currentReadPos; +} drwav__memory_stream; + +/* Structure for internal use. Only used for writers opened with drwav_init_memory_write(). */ +typedef struct +{ + void** ppData; + size_t* pDataSize; + size_t dataSize; + size_t dataCapacity; + size_t currentWritePos; +} drwav__memory_stream_write; + +typedef struct +{ + drwav_container container; /* RIFF, W64. */ + drwav_uint32 format; /* DR_WAVE_FORMAT_* */ + drwav_uint32 channels; + drwav_uint32 sampleRate; + drwav_uint32 bitsPerSample; +} drwav_data_format; + + +/* See the following for details on the 'smpl' chunk: https://sites.google.com/site/musicgapi/technical-documents/wav-file-format#smpl */ +typedef struct +{ + drwav_uint32 cuePointId; + drwav_uint32 type; + drwav_uint32 start; + drwav_uint32 end; + drwav_uint32 fraction; + drwav_uint32 playCount; +} drwav_smpl_loop; + + typedef struct +{ + drwav_uint32 manufacturer; + drwav_uint32 product; + drwav_uint32 samplePeriod; + drwav_uint32 midiUnityNotes; + drwav_uint32 midiPitchFraction; + drwav_uint32 smpteFormat; + drwav_uint32 smpteOffset; + drwav_uint32 numSampleLoops; + drwav_uint32 samplerData; + drwav_smpl_loop loops[DRWAV_MAX_SMPL_LOOPS]; +} drwav_smpl; + +typedef struct +{ + /* A pointer to the function to call when more data is needed. */ + drwav_read_proc onRead; + + /* A pointer to the function to call when data needs to be written. Only used when the drwav object is opened in write mode. */ + drwav_write_proc onWrite; + + /* A pointer to the function to call when the wav file needs to be seeked. */ + drwav_seek_proc onSeek; + + /* The user data to pass to callbacks. */ + void* pUserData; + + /* Allocation callbacks. */ + drwav_allocation_callbacks allocationCallbacks; + + + /* Whether or not the WAV file is formatted as a standard RIFF file or W64. */ + drwav_container container; + + + /* Structure containing format information exactly as specified by the wav file. */ + drwav_fmt fmt; + + /* The sample rate. Will be set to something like 44100. */ + drwav_uint32 sampleRate; + + /* The number of channels. This will be set to 1 for monaural streams, 2 for stereo, etc. */ + drwav_uint16 channels; + + /* The bits per sample. Will be set to something like 16, 24, etc. 
*/ + drwav_uint16 bitsPerSample; + + /* Equal to fmt.formatTag, or the value specified by fmt.subFormat if fmt.formatTag is equal to 65534 (WAVE_FORMAT_EXTENSIBLE). */ + drwav_uint16 translatedFormatTag; + + /* The total number of PCM frames making up the audio data. */ + drwav_uint64 totalPCMFrameCount; + + + /* The size in bytes of the data chunk. */ + drwav_uint64 dataChunkDataSize; + + /* The position in the stream of the first byte of the data chunk. This is used for seeking. */ + drwav_uint64 dataChunkDataPos; + + /* The number of bytes remaining in the data chunk. */ + drwav_uint64 bytesRemaining; + + + /* + Only used in sequential write mode. Keeps track of the desired size of the "data" chunk at the point of initialization time. Always + set to 0 for non-sequential writes and when the drwav object is opened in read mode. Used for validation. + */ + drwav_uint64 dataChunkDataSizeTargetWrite; + + /* Keeps track of whether or not the wav writer was initialized in sequential mode. */ + drwav_bool32 isSequentialWrite; + + + /* smpl chunk. */ + drwav_smpl smpl; + + + /* A hack to avoid a DRWAV_MALLOC() when opening a decoder with drwav_init_memory(). */ + drwav__memory_stream memoryStream; + drwav__memory_stream_write memoryStreamWrite; + + /* Generic data for compressed formats. This data is shared across all block-compressed formats. */ + struct + { + drwav_uint64 iCurrentPCMFrame; /* The index of the next PCM frame that will be read by drwav_read_*(). This is used with "totalPCMFrameCount" to ensure we don't read excess samples at the end of the last block. */ + } compressed; + + /* Microsoft ADPCM specific data. */ + struct + { + drwav_uint32 bytesRemainingInBlock; + drwav_uint16 predictor[2]; + drwav_int32 delta[2]; + drwav_int32 cachedFrames[4]; /* Samples are stored in this cache during decoding. */ + drwav_uint32 cachedFrameCount; + drwav_int32 prevFrames[2][2]; /* The previous 2 samples for each channel (2 channels at most). */ + } msadpcm; + + /* IMA ADPCM specific data. */ + struct + { + drwav_uint32 bytesRemainingInBlock; + drwav_int32 predictor[2]; + drwav_int32 stepIndex[2]; + drwav_int32 cachedFrames[16]; /* Samples are stored in this cache during decoding. */ + drwav_uint32 cachedFrameCount; + } ima; +} drwav; + + +/* +Initializes a pre-allocated drwav object for reading. + +pWav [out] A pointer to the drwav object being initialized. +onRead [in] The function to call when data needs to be read from the client. +onSeek [in] The function to call when the read position of the client data needs to move. +onChunk [in, optional] The function to call when a chunk is enumerated at initialized time. +pUserData, pReadSeekUserData [in, optional] A pointer to application defined data that will be passed to onRead and onSeek. +pChunkUserData [in, optional] A pointer to application defined data that will be passed to onChunk. +flags [in, optional] A set of flags for controlling how things are loaded. + +Returns true if successful; false otherwise. + +Close the loader with drwav_uninit(). + +This is the lowest level function for initializing a WAV file. You can also use drwav_init_file() and drwav_init_memory() +to open the stream from a file or from a block of memory respectively. + +Possible values for flags: + DRWAV_SEQUENTIAL: Never perform a backwards seek while loading. This disables the chunk callback and will cause this function + to return as soon as the data chunk is found. Any chunks after the data chunk will be ignored. 
+ +drwav_init() is equivalent to "drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0);". + +The onChunk callback is not called for the WAVE or FMT chunks. The contents of the FMT chunk can be read from pWav->fmt +after the function returns. + +See also: drwav_init_file(), drwav_init_memory(), drwav_uninit() +*/ +DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Initializes a pre-allocated drwav object for writing. + +onWrite [in] The function to call when data needs to be written. +onSeek [in] The function to call when the write position needs to move. +pUserData [in, optional] A pointer to application defined data that will be passed to onWrite and onSeek. + +Returns true if successful; false otherwise. + +Close the writer with drwav_uninit(). + +This is the lowest level function for initializing a WAV file. You can also use drwav_init_file_write() and drwav_init_memory_write() +to open the stream from a file or from a block of memory respectively. + +If the total sample count is known, you can use drwav_init_write_sequential(). This avoids the need for dr_wav to perform +a post-processing step for storing the total sample count and the size of the data chunk which requires a backwards seek. + +See also: drwav_init_file_write(), drwav_init_memory_write(), drwav_uninit() +*/ +DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Utility function to determine the target size of the entire data to be written (including all headers and chunks). + +Returns the target size in bytes. + +Useful if the application needs to know the size to allocate. + +Only writing to the RIFF chunk and one data chunk is currently supported. + +See also: drwav_init_write(), drwav_init_file_write(), drwav_init_memory_write() +*/ +DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount); + +/* +Uninitializes the given drwav object. + +Use this only for objects initialized with drwav_init*() functions (drwav_init(), drwav_init_ex(), drwav_init_write(), drwav_init_write_sequential()). +*/ +DRWAV_API drwav_result drwav_uninit(drwav* pWav); + + +/* +Reads raw audio data. + +This is the lowest level function for reading audio data. It simply reads the given number of +bytes of the raw internal sample data. + +Consider using drwav_read_pcm_frames_s16(), drwav_read_pcm_frames_s32() or drwav_read_pcm_frames_f32() for +reading sample data in a consistent format. + +pBufferOut can be NULL in which case a seek will be performed. 
+
+Returns the number of bytes actually read.
+*/
+DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut);
+
+/*
+Reads up to the specified number of PCM frames from the WAV file.
+
+The output data will be in the file's internal format, converted to native-endian byte order. Use
+drwav_read_pcm_frames_s16/f32/s32() to read data in a specific format.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached or
+you have requested more PCM frames than can possibly fit in the output buffer.
+
+This function will only work when sample data is of a fixed size and uncompressed. If you are
+using a compressed format consider using drwav_read_raw() or drwav_read_pcm_frames_s16/s32/f32().
+
+pBufferOut can be NULL in which case a seek will be performed.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut);
+
+/*
+Seeks to the given PCM frame.
+
+Returns true if successful; false otherwise.
+*/
+DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex);
+
+
+/*
+Writes raw audio data.
+
+Returns the number of bytes actually written. If this differs from bytesToWrite, it indicates an error.
+*/
+DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData);
+
+/*
+Writes PCM frames.
+
+Returns the number of PCM frames written.
+
+Input samples need to be in native-endian byte order. On big-endian architectures the input data will be converted to
+little-endian. Use drwav_write_raw() to write raw audio data without performing any conversion.
+*/
+DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData);
+
+
+/* Conversion Utilities */
+#ifndef DR_WAV_NO_CONVERSION_API
+
+/*
+Reads a chunk of audio data and converts it to signed 16-bit PCM samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 32-bit PCM samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 32-bit floating point samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to signed 16-bit PCM samples. */
+DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+
+/*
+Reads a chunk of audio data and converts it to IEEE 32-bit floating point samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut);
+
+/* Low-level function for converting unsigned 8-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 16-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 24-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting signed 32-bit PCM samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount);
+
+/* Low-level function for converting IEEE 64-bit floating point samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount);
+
+/* Low-level function for converting A-law samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+/* Low-level function for converting u-law samples to IEEE 32-bit floating point samples. */
+DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount);
+
+
+/*
+Reads a chunk of audio data and converts it to signed 32-bit PCM samples.
+
+pBufferOut can be NULL in which case a seek will be performed.
+
+Returns the number of PCM frames actually read.
+
+If the return value is less than <framesToRead> it means the end of the file has been reached.
+*/ +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut); + +/* Low-level function for converting unsigned 8-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting signed 16-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount); + +/* Low-level function for converting signed 24-bit PCM samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 32-bit floating point samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount); + +/* Low-level function for converting IEEE 64-bit floating point samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount); + +/* Low-level function for converting A-law samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +/* Low-level function for converting u-law samples to signed 32-bit PCM samples. */ +DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount); + +#endif /* DR_WAV_NO_CONVERSION_API */ + + +/* High-Level Convenience Helpers */ + +#ifndef DR_WAV_NO_STDIO +/* +Helper for initializing a wave file for reading using stdio. + +This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav +objects because the operating system may restrict the number of file handles an application can have open at +any given time. +*/ +DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Helper for initializing a wave file for writing using stdio. + +This holds the internal FILE object until drwav_uninit() is called. Keep this in mind if you're caching drwav +objects because the operating system may restrict the number of file handles an application can have open at +any given time. 
+*/ +DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif /* DR_WAV_NO_STDIO */ + +/* +Helper for initializing a loader from a pre-allocated memory buffer. + +This does not create a copy of the data. It is up to the application to ensure the buffer remains valid for +the lifetime of the drwav object. + +The buffer should contain the contents of the entire wave file, not just the sample data. +*/ +DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks); + +/* +Helper for initializing a writer which outputs data to a memory buffer. + +dr_wav will manage the memory allocations, however it is up to the caller to free the data with drwav_free(). + +The buffer will remain allocated even after drwav_uninit() is called. The buffer should not be considered valid +until after drwav_uninit() has been called. +*/ +DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks); + + +#ifndef DR_WAV_NO_CONVERSION_API +/* +Opens and reads an entire wav file in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. 
+*/ +DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#ifndef DR_WAV_NO_STDIO +/* +Opens and decodes an entire wav file in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. +*/ +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif +/* +Opens and decodes an entire wav file from a block of memory in a single operation. + +The return value is a heap-allocated buffer containing the audio data. Use drwav_free() to free the buffer. +*/ +DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks); +#endif + +/* Frees data that was allocated internally by dr_wav. 
+
+/* Frees data that was allocated internally by dr_wav. */
+DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks);
+
+/* Converts bytes from a wav stream to a sized type of native endian. */
+DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data);
+DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data);
+DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data);
+DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data);
+DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data);
+DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data);
+
+/* Compares a GUID for the purpose of checking the type of a Wave64 chunk. */
+DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]);
+
+/* Compares a four-character-code for the purpose of checking the type of a RIFF chunk. */
+DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b);
+
+#ifdef __cplusplus
+}
+#endif
+#endif /* dr_wav_h */
+
+
+/************************************************************************************************************************************************************
+ ************************************************************************************************************************************************************
+
+ IMPLEMENTATION
+
+ ************************************************************************************************************************************************************
+ ************************************************************************************************************************************************************/
+#if defined(DR_WAV_IMPLEMENTATION) || defined(DRWAV_IMPLEMENTATION)
+#ifndef dr_wav_c
+#define dr_wav_c
+
+#include <stdlib.h>
+#include <string.h> /* For memcpy(), memset() */
+#include <limits.h> /* For INT_MAX */
+
+#ifndef DR_WAV_NO_STDIO
+#include <stdio.h>
+#include <wchar.h>
+#endif
+
+/* Standard library stuff. */
+#ifndef DRWAV_ASSERT
+#include <assert.h>
+#define DRWAV_ASSERT(expression) assert(expression)
+#endif
+#ifndef DRWAV_MALLOC
+#define DRWAV_MALLOC(sz) malloc((sz))
+#endif
+#ifndef DRWAV_REALLOC
+#define DRWAV_REALLOC(p, sz) realloc((p), (sz))
+#endif
+#ifndef DRWAV_FREE
+#define DRWAV_FREE(p) free((p))
+#endif
+#ifndef DRWAV_COPY_MEMORY
+#define DRWAV_COPY_MEMORY(dst, src, sz) memcpy((dst), (src), (sz))
+#endif
+#ifndef DRWAV_ZERO_MEMORY
+#define DRWAV_ZERO_MEMORY(p, sz) memset((p), 0, (sz))
+#endif
+#ifndef DRWAV_ZERO_OBJECT
+#define DRWAV_ZERO_OBJECT(p) DRWAV_ZERO_MEMORY((p), sizeof(*p))
+#endif
+
+#define drwav_countof(x) (sizeof(x) / sizeof(x[0]))
+#define drwav_align(x, a) ((((x) + (a) - 1) / (a)) * (a))
+#define drwav_min(a, b) (((a) < (b)) ? (a) : (b))
+#define drwav_max(a, b) (((a) > (b)) ? (a) : (b))
+#define drwav_clamp(x, lo, hi) (drwav_max((lo), drwav_min((hi), (x))))
+
+#define DRWAV_MAX_SIMD_VECTOR_SIZE 64 /* 64 for AVX-512 in the future. */
+
+/* CPU architecture. */
+#if defined(__x86_64__) || defined(_M_X64)
+    #define DRWAV_X64
+#elif defined(__i386) || defined(_M_IX86)
+    #define DRWAV_X86
+#elif defined(__arm__) || defined(_M_ARM)
+    #define DRWAV_ARM
+#endif
+
+#ifdef _MSC_VER
+    #define DRWAV_INLINE __forceinline
+#elif defined(__GNUC__)
+    /*
+    I've had a bug report where GCC is emitting warnings about functions possibly not being inlineable. This warning happens when
+    the __attribute__((always_inline)) attribute is defined without an "inline" statement. I think therefore there must be some
+    case where "__inline__" is not always defined, thus the compiler emitting these warnings.
When using -std=c89 or -ansi on the + command line, we cannot use the "inline" keyword and instead need to use "__inline__". In an attempt to work around this issue + I am using "__inline__" only when we're compiling in strict ANSI mode. + */ + #if defined(__STRICT_ANSI__) + #define DRWAV_INLINE __inline__ __attribute__((always_inline)) + #else + #define DRWAV_INLINE inline __attribute__((always_inline)) + #endif +#elif defined(__WATCOMC__) + #define DRWAV_INLINE __inline +#else + #define DRWAV_INLINE +#endif + +#if defined(SIZE_MAX) + #define DRWAV_SIZE_MAX SIZE_MAX +#else + #if defined(_WIN64) || defined(_LP64) || defined(__LP64__) + #define DRWAV_SIZE_MAX ((drwav_uint64)0xFFFFFFFFFFFFFFFF) + #else + #define DRWAV_SIZE_MAX 0xFFFFFFFF + #endif +#endif + +#if defined(_MSC_VER) && _MSC_VER >= 1400 + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #define DRWAV_HAS_BYTESWAP64_INTRINSIC +#elif defined(__clang__) + #if defined(__has_builtin) + #if __has_builtin(__builtin_bswap16) + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #endif + #if __has_builtin(__builtin_bswap32) + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #endif + #if __has_builtin(__builtin_bswap64) + #define DRWAV_HAS_BYTESWAP64_INTRINSIC + #endif + #endif +#elif defined(__GNUC__) + #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 3)) + #define DRWAV_HAS_BYTESWAP32_INTRINSIC + #define DRWAV_HAS_BYTESWAP64_INTRINSIC + #endif + #if ((__GNUC__ > 4) || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)) + #define DRWAV_HAS_BYTESWAP16_INTRINSIC + #endif +#endif + +DRWAV_API void drwav_version(drwav_uint32* pMajor, drwav_uint32* pMinor, drwav_uint32* pRevision) +{ + if (pMajor) { + *pMajor = DRWAV_VERSION_MAJOR; + } + + if (pMinor) { + *pMinor = DRWAV_VERSION_MINOR; + } + + if (pRevision) { + *pRevision = DRWAV_VERSION_REVISION; + } +} + +DRWAV_API const char* drwav_version_string(void) +{ + return DRWAV_VERSION_STRING; +} + +/* +These limits are used for basic validation when initializing the decoder. If you exceed these limits, first of all: what on Earth are +you doing?! (Let me know, I'd be curious!) Second, you can adjust these by #define-ing them before the dr_wav implementation. 
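As the comment above suggests, the caps can be raised for unusual files by defining the limits before the implementation is compiled in. A minimal sketch (the 1024-channel value is an arbitrary example):

    #define DRWAV_MAX_CHANNELS 1024   /* override the default of 256 defined below */
    #define DR_WAV_IMPLEMENTATION
    #include "dr_wav.h"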
+*/ +#ifndef DRWAV_MAX_SAMPLE_RATE +#define DRWAV_MAX_SAMPLE_RATE 384000 +#endif +#ifndef DRWAV_MAX_CHANNELS +#define DRWAV_MAX_CHANNELS 256 +#endif +#ifndef DRWAV_MAX_BITS_PER_SAMPLE +#define DRWAV_MAX_BITS_PER_SAMPLE 64 +#endif + +static const drwav_uint8 drwavGUID_W64_RIFF[16] = {0x72,0x69,0x66,0x66, 0x2E,0x91, 0xCF,0x11, 0xA5,0xD6, 0x28,0xDB,0x04,0xC1,0x00,0x00}; /* 66666972-912E-11CF-A5D6-28DB04C10000 */ +static const drwav_uint8 drwavGUID_W64_WAVE[16] = {0x77,0x61,0x76,0x65, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 65766177-ACF3-11D3-8CD1-00C04F8EDB8A */ +/*static const drwav_uint8 drwavGUID_W64_JUNK[16] = {0x6A,0x75,0x6E,0x6B, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A};*/ /* 6B6E756A-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_FMT [16] = {0x66,0x6D,0x74,0x20, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 20746D66-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_FACT[16] = {0x66,0x61,0x63,0x74, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 74636166-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_DATA[16] = {0x64,0x61,0x74,0x61, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 61746164-ACF3-11D3-8CD1-00C04F8EDB8A */ +static const drwav_uint8 drwavGUID_W64_SMPL[16] = {0x73,0x6D,0x70,0x6C, 0xF3,0xAC, 0xD3,0x11, 0x8C,0xD1, 0x00,0xC0,0x4F,0x8E,0xDB,0x8A}; /* 6C706D73-ACF3-11D3-8CD1-00C04F8EDB8A */ + +static DRWAV_INLINE drwav_bool32 drwav__guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]) +{ + int i; + for (i = 0; i < 16; i += 1) { + if (a[i] != b[i]) { + return DRWAV_FALSE; + } + } + + return DRWAV_TRUE; +} + +static DRWAV_INLINE drwav_bool32 drwav__fourcc_equal(const drwav_uint8* a, const char* b) +{ + return + a[0] == b[0] && + a[1] == b[1] && + a[2] == b[2] && + a[3] == b[3]; +} + + + +static DRWAV_INLINE int drwav__is_little_endian(void) +{ +#if defined(DRWAV_X86) || defined(DRWAV_X64) + return DRWAV_TRUE; +#elif defined(__BYTE_ORDER) && defined(__LITTLE_ENDIAN) && __BYTE_ORDER == __LITTLE_ENDIAN + return DRWAV_TRUE; +#else + int n = 1; + return (*(char*)&n) == 1; +#endif +} + +static DRWAV_INLINE drwav_uint16 drwav__bytes_to_u16(const drwav_uint8* data) +{ + return (data[0] << 0) | (data[1] << 8); +} + +static DRWAV_INLINE drwav_int16 drwav__bytes_to_s16(const drwav_uint8* data) +{ + return (short)drwav__bytes_to_u16(data); +} + +static DRWAV_INLINE drwav_uint32 drwav__bytes_to_u32(const drwav_uint8* data) +{ + return (data[0] << 0) | (data[1] << 8) | (data[2] << 16) | (data[3] << 24); +} + +static DRWAV_INLINE drwav_int32 drwav__bytes_to_s32(const drwav_uint8* data) +{ + return (drwav_int32)drwav__bytes_to_u32(data); +} + +static DRWAV_INLINE drwav_uint64 drwav__bytes_to_u64(const drwav_uint8* data) +{ + return + ((drwav_uint64)data[0] << 0) | ((drwav_uint64)data[1] << 8) | ((drwav_uint64)data[2] << 16) | ((drwav_uint64)data[3] << 24) | + ((drwav_uint64)data[4] << 32) | ((drwav_uint64)data[5] << 40) | ((drwav_uint64)data[6] << 48) | ((drwav_uint64)data[7] << 56); +} + +static DRWAV_INLINE drwav_int64 drwav__bytes_to_s64(const drwav_uint8* data) +{ + return (drwav_int64)drwav__bytes_to_u64(data); +} + +static DRWAV_INLINE void drwav__bytes_to_guid(const drwav_uint8* data, drwav_uint8* guid) +{ + int i; + for (i = 0; i < 16; ++i) { + guid[i] = data[i]; + } +} + + +static DRWAV_INLINE drwav_uint16 drwav__bswap16(drwav_uint16 n) +{ +#ifdef DRWAV_HAS_BYTESWAP16_INTRINSIC + #if defined(_MSC_VER) + 
return _byteswap_ushort(n); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_bswap16(n); + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + return ((n & 0xFF00) >> 8) | + ((n & 0x00FF) << 8); +#endif +} + +static DRWAV_INLINE drwav_uint32 drwav__bswap32(drwav_uint32 n) +{ +#ifdef DRWAV_HAS_BYTESWAP32_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_ulong(n); + #elif defined(__GNUC__) || defined(__clang__) + #if defined(DRWAV_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 6) && !defined(DRWAV_64BIT) /* <-- 64-bit inline assembly has not been tested, so disabling for now. */ + /* Inline assembly optimized implementation for ARM. In my testing, GCC does not generate optimized code with __builtin_bswap32(). */ + drwav_uint32 r; + __asm__ __volatile__ ( + #if defined(DRWAV_64BIT) + "rev %w[out], %w[in]" : [out]"=r"(r) : [in]"r"(n) /* <-- This is untested. If someone in the community could test this, that would be appreciated! */ + #else + "rev %[out], %[in]" : [out]"=r"(r) : [in]"r"(n) + #endif + ); + return r; + #else + return __builtin_bswap32(n); + #endif + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + return ((n & 0xFF000000) >> 24) | + ((n & 0x00FF0000) >> 8) | + ((n & 0x0000FF00) << 8) | + ((n & 0x000000FF) << 24); +#endif +} + +static DRWAV_INLINE drwav_uint64 drwav__bswap64(drwav_uint64 n) +{ +#ifdef DRWAV_HAS_BYTESWAP64_INTRINSIC + #if defined(_MSC_VER) + return _byteswap_uint64(n); + #elif defined(__GNUC__) || defined(__clang__) + return __builtin_bswap64(n); + #else + #error "This compiler does not support the byte swap intrinsic." + #endif +#else + /* Weird "<< 32" bitshift is required for C89 because it doesn't support 64-bit constants. Should be optimized out by a good compiler. 
*/ + return ((n & ((drwav_uint64)0xFF000000 << 32)) >> 56) | + ((n & ((drwav_uint64)0x00FF0000 << 32)) >> 40) | + ((n & ((drwav_uint64)0x0000FF00 << 32)) >> 24) | + ((n & ((drwav_uint64)0x000000FF << 32)) >> 8) | + ((n & ((drwav_uint64)0xFF000000 )) << 8) | + ((n & ((drwav_uint64)0x00FF0000 )) << 24) | + ((n & ((drwav_uint64)0x0000FF00 )) << 40) | + ((n & ((drwav_uint64)0x000000FF )) << 56); +#endif +} + + +static DRWAV_INLINE drwav_int16 drwav__bswap_s16(drwav_int16 n) +{ + return (drwav_int16)drwav__bswap16((drwav_uint16)n); +} + +static DRWAV_INLINE void drwav__bswap_samples_s16(drwav_int16* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_s16(pSamples[iSample]); + } +} + + +static DRWAV_INLINE void drwav__bswap_s24(drwav_uint8* p) +{ + drwav_uint8 t; + t = p[0]; + p[0] = p[2]; + p[2] = t; +} + +static DRWAV_INLINE void drwav__bswap_samples_s24(drwav_uint8* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + drwav_uint8* pSample = pSamples + (iSample*3); + drwav__bswap_s24(pSample); + } +} + + +static DRWAV_INLINE drwav_int32 drwav__bswap_s32(drwav_int32 n) +{ + return (drwav_int32)drwav__bswap32((drwav_uint32)n); +} + +static DRWAV_INLINE void drwav__bswap_samples_s32(drwav_int32* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_s32(pSamples[iSample]); + } +} + + +static DRWAV_INLINE float drwav__bswap_f32(float n) +{ + union { + drwav_uint32 i; + float f; + } x; + x.f = n; + x.i = drwav__bswap32(x.i); + + return x.f; +} + +static DRWAV_INLINE void drwav__bswap_samples_f32(float* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_f32(pSamples[iSample]); + } +} + + +static DRWAV_INLINE double drwav__bswap_f64(double n) +{ + union { + drwav_uint64 i; + double f; + } x; + x.f = n; + x.i = drwav__bswap64(x.i); + + return x.f; +} + +static DRWAV_INLINE void drwav__bswap_samples_f64(double* pSamples, drwav_uint64 sampleCount) +{ + drwav_uint64 iSample; + for (iSample = 0; iSample < sampleCount; iSample += 1) { + pSamples[iSample] = drwav__bswap_f64(pSamples[iSample]); + } +} + + +static DRWAV_INLINE void drwav__bswap_samples_pcm(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample) +{ + /* Assumes integer PCM. Floating point PCM is done in drwav__bswap_samples_ieee(). */ + switch (bytesPerSample) + { + case 2: /* s16, s12 (loosely packed) */ + { + drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount); + } break; + case 3: /* s24 */ + { + drwav__bswap_samples_s24((drwav_uint8*)pSamples, sampleCount); + } break; + case 4: /* s32 */ + { + drwav__bswap_samples_s32((drwav_int32*)pSamples, sampleCount); + } break; + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + +static DRWAV_INLINE void drwav__bswap_samples_ieee(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample) +{ + switch (bytesPerSample) + { + #if 0 /* Contributions welcome for f16 support. 
*/ + case 2: /* f16 */ + { + drwav__bswap_samples_f16((drwav_float16*)pSamples, sampleCount); + } break; + #endif + case 4: /* f32 */ + { + drwav__bswap_samples_f32((float*)pSamples, sampleCount); + } break; + case 8: /* f64 */ + { + drwav__bswap_samples_f64((double*)pSamples, sampleCount); + } break; + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + +static DRWAV_INLINE void drwav__bswap_samples(void* pSamples, drwav_uint64 sampleCount, drwav_uint32 bytesPerSample, drwav_uint16 format) +{ + switch (format) + { + case DR_WAVE_FORMAT_PCM: + { + drwav__bswap_samples_pcm(pSamples, sampleCount, bytesPerSample); + } break; + + case DR_WAVE_FORMAT_IEEE_FLOAT: + { + drwav__bswap_samples_ieee(pSamples, sampleCount, bytesPerSample); + } break; + + case DR_WAVE_FORMAT_ALAW: + case DR_WAVE_FORMAT_MULAW: + { + drwav__bswap_samples_s16((drwav_int16*)pSamples, sampleCount); + } break; + + case DR_WAVE_FORMAT_ADPCM: + case DR_WAVE_FORMAT_DVI_ADPCM: + default: + { + /* Unsupported format. */ + DRWAV_ASSERT(DRWAV_FALSE); + } break; + } +} + + +static void* drwav__malloc_default(size_t sz, void* pUserData) +{ + (void)pUserData; + return DRWAV_MALLOC(sz); +} + +static void* drwav__realloc_default(void* p, size_t sz, void* pUserData) +{ + (void)pUserData; + return DRWAV_REALLOC(p, sz); +} + +static void drwav__free_default(void* p, void* pUserData) +{ + (void)pUserData; + DRWAV_FREE(p); +} + + +static void* drwav__malloc_from_callbacks(size_t sz, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks == NULL) { + return NULL; + } + + if (pAllocationCallbacks->onMalloc != NULL) { + return pAllocationCallbacks->onMalloc(sz, pAllocationCallbacks->pUserData); + } + + /* Try using realloc(). */ + if (pAllocationCallbacks->onRealloc != NULL) { + return pAllocationCallbacks->onRealloc(NULL, sz, pAllocationCallbacks->pUserData); + } + + return NULL; +} + +static void* drwav__realloc_from_callbacks(void* p, size_t szNew, size_t szOld, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks == NULL) { + return NULL; + } + + if (pAllocationCallbacks->onRealloc != NULL) { + return pAllocationCallbacks->onRealloc(p, szNew, pAllocationCallbacks->pUserData); + } + + /* Try emulating realloc() in terms of malloc()/free(). */ + if (pAllocationCallbacks->onMalloc != NULL && pAllocationCallbacks->onFree != NULL) { + void* p2; + + p2 = pAllocationCallbacks->onMalloc(szNew, pAllocationCallbacks->pUserData); + if (p2 == NULL) { + return NULL; + } + + if (p != NULL) { + DRWAV_COPY_MEMORY(p2, p, szOld); + pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); + } + + return p2; + } + + return NULL; +} + +static void drwav__free_from_callbacks(void* p, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (p == NULL || pAllocationCallbacks == NULL) { + return; + } + + if (pAllocationCallbacks->onFree != NULL) { + pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); + } +} + + +static drwav_allocation_callbacks drwav_copy_allocation_callbacks_or_defaults(const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks != NULL) { + /* Copy. */ + return *pAllocationCallbacks; + } else { + /* Defaults. 
*/ + drwav_allocation_callbacks allocationCallbacks; + allocationCallbacks.pUserData = NULL; + allocationCallbacks.onMalloc = drwav__malloc_default; + allocationCallbacks.onRealloc = drwav__realloc_default; + allocationCallbacks.onFree = drwav__free_default; + return allocationCallbacks; + } +} + + +static DRWAV_INLINE drwav_bool32 drwav__is_compressed_format_tag(drwav_uint16 formatTag) +{ + return + formatTag == DR_WAVE_FORMAT_ADPCM || + formatTag == DR_WAVE_FORMAT_DVI_ADPCM; +} + +static unsigned int drwav__chunk_padding_size_riff(drwav_uint64 chunkSize) +{ + return (unsigned int)(chunkSize % 2); +} + +static unsigned int drwav__chunk_padding_size_w64(drwav_uint64 chunkSize) +{ + return (unsigned int)(chunkSize % 8); +} + +static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut); +static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 samplesToRead, drwav_int16* pBufferOut); +static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount); + +static drwav_result drwav__read_chunk_header(drwav_read_proc onRead, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_chunk_header* pHeaderOut) +{ + if (container == drwav_container_riff || container == drwav_container_rf64) { + drwav_uint8 sizeInBytes[4]; + + if (onRead(pUserData, pHeaderOut->id.fourcc, 4) != 4) { + return DRWAV_AT_END; + } + + if (onRead(pUserData, sizeInBytes, 4) != 4) { + return DRWAV_INVALID_FILE; + } + + pHeaderOut->sizeInBytes = drwav__bytes_to_u32(sizeInBytes); + pHeaderOut->paddingSize = drwav__chunk_padding_size_riff(pHeaderOut->sizeInBytes); + *pRunningBytesReadOut += 8; + } else { + drwav_uint8 sizeInBytes[8]; + + if (onRead(pUserData, pHeaderOut->id.guid, 16) != 16) { + return DRWAV_AT_END; + } + + if (onRead(pUserData, sizeInBytes, 8) != 8) { + return DRWAV_INVALID_FILE; + } + + pHeaderOut->sizeInBytes = drwav__bytes_to_u64(sizeInBytes) - 24; /* <-- Subtract 24 because w64 includes the size of the header. */ + pHeaderOut->paddingSize = drwav__chunk_padding_size_w64(pHeaderOut->sizeInBytes); + *pRunningBytesReadOut += 24; + } + + return DRWAV_SUCCESS; +} + +static drwav_bool32 drwav__seek_forward(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData) +{ + drwav_uint64 bytesRemainingToSeek = offset; + while (bytesRemainingToSeek > 0) { + if (bytesRemainingToSeek > 0x7FFFFFFF) { + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + bytesRemainingToSeek -= 0x7FFFFFFF; + } else { + if (!onSeek(pUserData, (int)bytesRemainingToSeek, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + bytesRemainingToSeek = 0; + } + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav__seek_from_start(drwav_seek_proc onSeek, drwav_uint64 offset, void* pUserData) +{ + if (offset <= 0x7FFFFFFF) { + return onSeek(pUserData, (int)offset, drwav_seek_origin_start); + } + + /* Larger than 32-bit seek. */ + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_start)) { + return DRWAV_FALSE; + } + offset -= 0x7FFFFFFF; + + for (;;) { + if (offset <= 0x7FFFFFFF) { + return onSeek(pUserData, (int)offset, drwav_seek_origin_current); + } + + if (!onSeek(pUserData, 0x7FFFFFFF, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + offset -= 0x7FFFFFFF; + } + + /* Should never get here. 
*/ + /*return DRWAV_TRUE; */ +} + + +static drwav_bool32 drwav__read_fmt(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, drwav_container container, drwav_uint64* pRunningBytesReadOut, drwav_fmt* fmtOut) +{ + drwav_chunk_header header; + drwav_uint8 fmt[16]; + + if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + + /* Skip non-fmt chunks. */ + while (((container == drwav_container_riff || container == drwav_container_rf64) && !drwav__fourcc_equal(header.id.fourcc, "fmt ")) || (container == drwav_container_w64 && !drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT))) { + if (!drwav__seek_forward(onSeek, header.sizeInBytes + header.paddingSize, pUserData)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += header.sizeInBytes + header.paddingSize; + + /* Try the next header. */ + if (drwav__read_chunk_header(onRead, pUserData, container, pRunningBytesReadOut, &header) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + } + + + /* Validation. */ + if (container == drwav_container_riff || container == drwav_container_rf64) { + if (!drwav__fourcc_equal(header.id.fourcc, "fmt ")) { + return DRWAV_FALSE; + } + } else { + if (!drwav__guid_equal(header.id.guid, drwavGUID_W64_FMT)) { + return DRWAV_FALSE; + } + } + + + if (onRead(pUserData, fmt, sizeof(fmt)) != sizeof(fmt)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += sizeof(fmt); + + fmtOut->formatTag = drwav__bytes_to_u16(fmt + 0); + fmtOut->channels = drwav__bytes_to_u16(fmt + 2); + fmtOut->sampleRate = drwav__bytes_to_u32(fmt + 4); + fmtOut->avgBytesPerSec = drwav__bytes_to_u32(fmt + 8); + fmtOut->blockAlign = drwav__bytes_to_u16(fmt + 12); + fmtOut->bitsPerSample = drwav__bytes_to_u16(fmt + 14); + + fmtOut->extendedSize = 0; + fmtOut->validBitsPerSample = 0; + fmtOut->channelMask = 0; + memset(fmtOut->subFormat, 0, sizeof(fmtOut->subFormat)); + + if (header.sizeInBytes > 16) { + drwav_uint8 fmt_cbSize[2]; + int bytesReadSoFar = 0; + + if (onRead(pUserData, fmt_cbSize, sizeof(fmt_cbSize)) != sizeof(fmt_cbSize)) { + return DRWAV_FALSE; /* Expecting more data. */ + } + *pRunningBytesReadOut += sizeof(fmt_cbSize); + + bytesReadSoFar = 18; + + fmtOut->extendedSize = drwav__bytes_to_u16(fmt_cbSize); + if (fmtOut->extendedSize > 0) { + /* Simple validation. */ + if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + if (fmtOut->extendedSize != 22) { + return DRWAV_FALSE; + } + } + + if (fmtOut->formatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + drwav_uint8 fmtext[22]; + if (onRead(pUserData, fmtext, fmtOut->extendedSize) != fmtOut->extendedSize) { + return DRWAV_FALSE; /* Expecting more data. */ + } + + fmtOut->validBitsPerSample = drwav__bytes_to_u16(fmtext + 0); + fmtOut->channelMask = drwav__bytes_to_u32(fmtext + 2); + drwav__bytes_to_guid(fmtext + 6, fmtOut->subFormat); + } else { + if (!onSeek(pUserData, fmtOut->extendedSize, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + } + *pRunningBytesReadOut += fmtOut->extendedSize; + + bytesReadSoFar += fmtOut->extendedSize; + } + + /* Seek past any leftover bytes. For w64 the leftover will be defined based on the chunk size. 
*/ + if (!onSeek(pUserData, (int)(header.sizeInBytes - bytesReadSoFar), drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += (header.sizeInBytes - bytesReadSoFar); + } + + if (header.paddingSize > 0) { + if (!onSeek(pUserData, header.paddingSize, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + *pRunningBytesReadOut += header.paddingSize; + } + + return DRWAV_TRUE; +} + + +static size_t drwav__on_read(drwav_read_proc onRead, void* pUserData, void* pBufferOut, size_t bytesToRead, drwav_uint64* pCursor) +{ + size_t bytesRead; + + DRWAV_ASSERT(onRead != NULL); + DRWAV_ASSERT(pCursor != NULL); + + bytesRead = onRead(pUserData, pBufferOut, bytesToRead); + *pCursor += bytesRead; + return bytesRead; +} + +#if 0 +static drwav_bool32 drwav__on_seek(drwav_seek_proc onSeek, void* pUserData, int offset, drwav_seek_origin origin, drwav_uint64* pCursor) +{ + DRWAV_ASSERT(onSeek != NULL); + DRWAV_ASSERT(pCursor != NULL); + + if (!onSeek(pUserData, offset, origin)) { + return DRWAV_FALSE; + } + + if (origin == drwav_seek_origin_start) { + *pCursor = offset; + } else { + *pCursor += offset; + } + + return DRWAV_TRUE; +} +#endif + + + +static drwav_uint32 drwav_get_bytes_per_pcm_frame(drwav* pWav) +{ + /* + The bytes per frame is a bit ambiguous. It can be either be based on the bits per sample, or the block align. The way I'm doing it here + is that if the bits per sample is a multiple of 8, use floor(bitsPerSample*channels/8), otherwise fall back to the block align. + */ + if ((pWav->bitsPerSample & 0x7) == 0) { + /* Bits per sample is a multiple of 8. */ + return (pWav->bitsPerSample * pWav->fmt.channels) >> 3; + } else { + return pWav->fmt.blockAlign; + } +} + +DRWAV_API drwav_uint16 drwav_fmt_get_format(const drwav_fmt* pFMT) +{ + if (pFMT == NULL) { + return 0; + } + + if (pFMT->formatTag != DR_WAVE_FORMAT_EXTENSIBLE) { + return pFMT->formatTag; + } else { + return drwav__bytes_to_u16(pFMT->subFormat); /* Only the first two bytes are required. */ + } +} + +static drwav_bool32 drwav_preinit(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pReadSeekUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pWav == NULL || onRead == NULL || onSeek == NULL) { + return DRWAV_FALSE; + } + + DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav)); + pWav->onRead = onRead; + pWav->onSeek = onSeek; + pWav->pUserData = pReadSeekUserData; + pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); + + if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { + return DRWAV_FALSE; /* Invalid allocation callbacks. */ + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init__internal(drwav* pWav, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags) +{ + /* This function assumes drwav_preinit() has been called beforehand. */ + + drwav_uint64 cursor; /* <-- Keeps track of the byte position so we can seek to specific locations. */ + drwav_bool32 sequential; + drwav_uint8 riff[4]; + drwav_fmt fmt; + unsigned short translatedFormatTag; + drwav_bool32 foundDataChunk; + drwav_uint64 dataChunkSize = 0; /* <-- Important! Don't explicitly set this to 0 anywhere else. Calculation of the size of the data chunk is performed in different paths depending on the container. */ + drwav_uint64 sampleCountFromFactChunk = 0; /* Same as dataChunkSize - make sure this is the only place this is initialized to 0. 
*/ + drwav_uint64 chunkSize; + + cursor = 0; + sequential = (flags & DRWAV_SEQUENTIAL) != 0; + + /* The first 4 bytes should be the RIFF identifier. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, riff, sizeof(riff), &cursor) != sizeof(riff)) { + return DRWAV_FALSE; + } + + /* + The first 4 bytes can be used to identify the container. For RIFF files it will start with "RIFF" and for + w64 it will start with "riff". + */ + if (drwav__fourcc_equal(riff, "RIFF")) { + pWav->container = drwav_container_riff; + } else if (drwav__fourcc_equal(riff, "riff")) { + int i; + drwav_uint8 riff2[12]; + + pWav->container = drwav_container_w64; + + /* Check the rest of the GUID for validity. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, riff2, sizeof(riff2), &cursor) != sizeof(riff2)) { + return DRWAV_FALSE; + } + + for (i = 0; i < 12; ++i) { + if (riff2[i] != drwavGUID_W64_RIFF[i+4]) { + return DRWAV_FALSE; + } + } + } else if (drwav__fourcc_equal(riff, "RF64")) { + pWav->container = drwav_container_rf64; + } else { + return DRWAV_FALSE; /* Unknown or unsupported container. */ + } + + + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + drwav_uint8 chunkSizeBytes[4]; + drwav_uint8 wave[4]; + + /* RIFF/WAVE */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) { + return DRWAV_FALSE; + } + + if (pWav->container == drwav_container_riff) { + if (drwav__bytes_to_u32(chunkSizeBytes) < 36) { + return DRWAV_FALSE; /* Chunk size should always be at least 36 bytes. */ + } + } else { + if (drwav__bytes_to_u32(chunkSizeBytes) != 0xFFFFFFFF) { + return DRWAV_FALSE; /* Chunk size should always be set to -1/0xFFFFFFFF for RF64. The actual size is retrieved later. */ + } + } + + if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) { + return DRWAV_FALSE; + } + + if (!drwav__fourcc_equal(wave, "WAVE")) { + return DRWAV_FALSE; /* Expecting "WAVE". */ + } + } else { + drwav_uint8 chunkSizeBytes[8]; + drwav_uint8 wave[16]; + + /* W64 */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, chunkSizeBytes, sizeof(chunkSizeBytes), &cursor) != sizeof(chunkSizeBytes)) { + return DRWAV_FALSE; + } + + if (drwav__bytes_to_u64(chunkSizeBytes) < 80) { + return DRWAV_FALSE; + } + + if (drwav__on_read(pWav->onRead, pWav->pUserData, wave, sizeof(wave), &cursor) != sizeof(wave)) { + return DRWAV_FALSE; + } + + if (!drwav__guid_equal(wave, drwavGUID_W64_WAVE)) { + return DRWAV_FALSE; + } + } + + + /* For RF64, the "ds64" chunk must come next, before the "fmt " chunk. */ + if (pWav->container == drwav_container_rf64) { + drwav_uint8 sizeBytes[8]; + drwav_uint64 bytesRemainingInChunk; + drwav_chunk_header header; + drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header); + if (result != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + if (!drwav__fourcc_equal(header.id.fourcc, "ds64")) { + return DRWAV_FALSE; /* Expecting "ds64". */ + } + + bytesRemainingInChunk = header.sizeInBytes + header.paddingSize; + + /* We don't care about the size of the RIFF chunk - skip it. */ + if (!drwav__seek_forward(pWav->onSeek, 8, pWav->pUserData)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + cursor += 8; + + + /* Next 8 bytes is the size of the "data" chunk. 
*/ + if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + dataChunkSize = drwav__bytes_to_u64(sizeBytes); + + + /* Next 8 bytes is the same count which we would usually derived from the FACT chunk if it was available. */ + if (drwav__on_read(pWav->onRead, pWav->pUserData, sizeBytes, sizeof(sizeBytes), &cursor) != sizeof(sizeBytes)) { + return DRWAV_FALSE; + } + bytesRemainingInChunk -= 8; + sampleCountFromFactChunk = drwav__bytes_to_u64(sizeBytes); + + + /* Skip over everything else. */ + if (!drwav__seek_forward(pWav->onSeek, bytesRemainingInChunk, pWav->pUserData)) { + return DRWAV_FALSE; + } + cursor += bytesRemainingInChunk; + } + + + /* The next bytes should be the "fmt " chunk. */ + if (!drwav__read_fmt(pWav->onRead, pWav->onSeek, pWav->pUserData, pWav->container, &cursor, &fmt)) { + return DRWAV_FALSE; /* Failed to read the "fmt " chunk. */ + } + + /* Basic validation. */ + if ((fmt.sampleRate == 0 || fmt.sampleRate > DRWAV_MAX_SAMPLE_RATE) || + (fmt.channels == 0 || fmt.channels > DRWAV_MAX_CHANNELS) || + (fmt.bitsPerSample == 0 || fmt.bitsPerSample > DRWAV_MAX_BITS_PER_SAMPLE) || + fmt.blockAlign == 0) { + return DRWAV_FALSE; /* Probably an invalid WAV file. */ + } + + + /* Translate the internal format. */ + translatedFormatTag = fmt.formatTag; + if (translatedFormatTag == DR_WAVE_FORMAT_EXTENSIBLE) { + translatedFormatTag = drwav__bytes_to_u16(fmt.subFormat + 0); + } + + + /* + We need to enumerate over each chunk for two reasons: + 1) The "data" chunk may not be the next one + 2) We may want to report each chunk back to the client + + In order to correctly report each chunk back to the client we will need to keep looping until the end of the file. + */ + foundDataChunk = DRWAV_FALSE; + + /* The next chunk we care about is the "data" chunk. This is not necessarily the next chunk so we'll need to loop. */ + for (;;) + { + drwav_chunk_header header; + drwav_result result = drwav__read_chunk_header(pWav->onRead, pWav->pUserData, pWav->container, &cursor, &header); + if (result != DRWAV_SUCCESS) { + if (!foundDataChunk) { + return DRWAV_FALSE; + } else { + break; /* Probably at the end of the file. Get out of the loop. */ + } + } + + /* Tell the client about this chunk. */ + if (!sequential && onChunk != NULL) { + drwav_uint64 callbackBytesRead = onChunk(pChunkUserData, pWav->onRead, pWav->onSeek, pWav->pUserData, &header, pWav->container, &fmt); + + /* + dr_wav may need to read the contents of the chunk, so we now need to seek back to the position before + we called the callback. + */ + if (callbackBytesRead > 0) { + if (!drwav__seek_from_start(pWav->onSeek, cursor, pWav->pUserData)) { + return DRWAV_FALSE; + } + } + } + + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + + chunkSize = header.sizeInBytes; + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + if (drwav__fourcc_equal(header.id.fourcc, "data")) { + foundDataChunk = DRWAV_TRUE; + if (pWav->container != drwav_container_rf64) { /* The data chunk size for RF64 will always be set to 0xFFFFFFFF here. It was set to it's true value earlier. */ + dataChunkSize = chunkSize; + } + } + } else { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_DATA)) { + foundDataChunk = DRWAV_TRUE; + dataChunkSize = chunkSize; + } + } + + /* + If at this point we have found the data chunk and we're running in sequential mode, we need to break out of this loop. 
The reason for + this is that we would otherwise require a backwards seek which sequential mode forbids. + */ + if (foundDataChunk && sequential) { + break; + } + + /* Optional. Get the total sample count from the FACT chunk. This is useful for compressed formats. */ + if (pWav->container == drwav_container_riff) { + if (drwav__fourcc_equal(header.id.fourcc, "fact")) { + drwav_uint32 sampleCount; + if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCount, 4, &cursor) != 4) { + return DRWAV_FALSE; + } + chunkSize -= 4; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + + /* + The sample count in the "fact" chunk is either unreliable, or I'm not understanding it properly. For now I am only enabling this + for Microsoft ADPCM formats. + */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + sampleCountFromFactChunk = sampleCount; + } else { + sampleCountFromFactChunk = 0; + } + } + } else if (pWav->container == drwav_container_w64) { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_FACT)) { + if (drwav__on_read(pWav->onRead, pWav->pUserData, &sampleCountFromFactChunk, 8, &cursor) != 8) { + return DRWAV_FALSE; + } + chunkSize -= 8; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + } + } else if (pWav->container == drwav_container_rf64) { + /* We retrieved the sample count from the ds64 chunk earlier so no need to do that here. */ + } + + /* "smpl" chunk. */ + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + if (drwav__fourcc_equal(header.id.fourcc, "smpl")) { + drwav_uint8 smplHeaderData[36]; /* 36 = size of the smpl header section, not including the loop data. */ + if (chunkSize >= sizeof(smplHeaderData)) { + drwav_uint64 bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplHeaderData, sizeof(smplHeaderData), &cursor); + chunkSize -= bytesJustRead; + + if (bytesJustRead == sizeof(smplHeaderData)) { + drwav_uint32 iLoop; + + pWav->smpl.manufacturer = drwav__bytes_to_u32(smplHeaderData+0); + pWav->smpl.product = drwav__bytes_to_u32(smplHeaderData+4); + pWav->smpl.samplePeriod = drwav__bytes_to_u32(smplHeaderData+8); + pWav->smpl.midiUnityNotes = drwav__bytes_to_u32(smplHeaderData+12); + pWav->smpl.midiPitchFraction = drwav__bytes_to_u32(smplHeaderData+16); + pWav->smpl.smpteFormat = drwav__bytes_to_u32(smplHeaderData+20); + pWav->smpl.smpteOffset = drwav__bytes_to_u32(smplHeaderData+24); + pWav->smpl.numSampleLoops = drwav__bytes_to_u32(smplHeaderData+28); + pWav->smpl.samplerData = drwav__bytes_to_u32(smplHeaderData+32); + + for (iLoop = 0; iLoop < pWav->smpl.numSampleLoops && iLoop < drwav_countof(pWav->smpl.loops); ++iLoop) { + drwav_uint8 smplLoopData[24]; /* 24 = size of a loop section in the smpl chunk. */ + bytesJustRead = drwav__on_read(pWav->onRead, pWav->pUserData, smplLoopData, sizeof(smplLoopData), &cursor); + chunkSize -= bytesJustRead; + + if (bytesJustRead == sizeof(smplLoopData)) { + pWav->smpl.loops[iLoop].cuePointId = drwav__bytes_to_u32(smplLoopData+0); + pWav->smpl.loops[iLoop].type = drwav__bytes_to_u32(smplLoopData+4); + pWav->smpl.loops[iLoop].start = drwav__bytes_to_u32(smplLoopData+8); + pWav->smpl.loops[iLoop].end = drwav__bytes_to_u32(smplLoopData+12); + pWav->smpl.loops[iLoop].fraction = drwav__bytes_to_u32(smplLoopData+16); + pWav->smpl.loops[iLoop].playCount = drwav__bytes_to_u32(smplLoopData+20); + } else { + break; /* Break from the smpl loop for loop. */ + } + } + } + } else { + /* Looks like invalid data. Ignore the chunk. 
*/ + } + } + } else { + if (drwav__guid_equal(header.id.guid, drwavGUID_W64_SMPL)) { + /* + This path will be hit when a W64 WAV file contains a smpl chunk. I don't have a sample file to test this path, so a contribution + is welcome to add support for this. + */ + } + } + + /* Make sure we seek past the padding. */ + chunkSize += header.paddingSize; + if (!drwav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData)) { + break; + } + cursor += chunkSize; + + if (!foundDataChunk) { + pWav->dataChunkDataPos = cursor; + } + } + + /* If we haven't found a data chunk, return an error. */ + if (!foundDataChunk) { + return DRWAV_FALSE; + } + + /* We may have moved passed the data chunk. If so we need to move back. If running in sequential mode we can assume we are already sitting on the data chunk. */ + if (!sequential) { + if (!drwav__seek_from_start(pWav->onSeek, pWav->dataChunkDataPos, pWav->pUserData)) { + return DRWAV_FALSE; + } + cursor = pWav->dataChunkDataPos; + } + + + /* At this point we should be sitting on the first byte of the raw audio data. */ + + pWav->fmt = fmt; + pWav->sampleRate = fmt.sampleRate; + pWav->channels = fmt.channels; + pWav->bitsPerSample = fmt.bitsPerSample; + pWav->bytesRemaining = dataChunkSize; + pWav->translatedFormatTag = translatedFormatTag; + pWav->dataChunkDataSize = dataChunkSize; + + if (sampleCountFromFactChunk != 0) { + pWav->totalPCMFrameCount = sampleCountFromFactChunk; + } else { + pWav->totalPCMFrameCount = dataChunkSize / drwav_get_bytes_per_pcm_frame(pWav); + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + drwav_uint64 totalBlockHeaderSizeInBytes; + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + + /* Make sure any trailing partial block is accounted for. */ + if ((blockCount * fmt.blockAlign) < dataChunkSize) { + blockCount += 1; + } + + /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */ + totalBlockHeaderSizeInBytes = blockCount * (6*fmt.channels); + pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels; + } + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + drwav_uint64 totalBlockHeaderSizeInBytes; + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + + /* Make sure any trailing partial block is accounted for. */ + if ((blockCount * fmt.blockAlign) < dataChunkSize) { + blockCount += 1; + } + + /* We decode two samples per byte. There will be blockCount headers in the data chunk. This is enough to know how to calculate the total PCM frame count. */ + totalBlockHeaderSizeInBytes = blockCount * (4*fmt.channels); + pWav->totalPCMFrameCount = ((dataChunkSize - totalBlockHeaderSizeInBytes) * 2) / fmt.channels; + + /* The header includes a decoded sample for each channel which acts as the initial predictor sample. */ + pWav->totalPCMFrameCount += blockCount; + } + } + + /* Some formats only support a certain number of channels. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM || pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + if (pWav->channels > 2) { + return DRWAV_FALSE; + } + } + +#ifdef DR_WAV_LIBSNDFILE_COMPAT + /* + I use libsndfile as a benchmark for testing, however in the version I'm using (from the Windows installer on the libsndfile website), + it appears the total sample count libsndfile uses for MS-ADPCM is incorrect. 
It would seem they are computing the total sample count + from the number of blocks, however this results in the inclusion of extra silent samples at the end of the last block. The correct + way to know the total sample count is to inspect the "fact" chunk, which should always be present for compressed formats, and should + always include the sample count. This little block of code below is only used to emulate the libsndfile logic so I can properly run my + correctness tests against libsndfile, and is disabled by default. + */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (6*pWav->channels))) * 2)) / fmt.channels; /* x2 because two samples per byte. */ + } + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + drwav_uint64 blockCount = dataChunkSize / fmt.blockAlign; + pWav->totalPCMFrameCount = (((blockCount * (fmt.blockAlign - (4*pWav->channels))) * 2) + (blockCount * pWav->channels)) / fmt.channels; + } +#endif + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_ex(drwav* pWav, drwav_read_proc onRead, drwav_seek_proc onSeek, drwav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init__internal(pWav, onChunk, pChunkUserData, flags); +} + + +static drwav_uint32 drwav__riff_chunk_size_riff(drwav_uint64 dataChunkSize) +{ + drwav_uint64 chunkSize = 4 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 24 = "fmt " chunk. */ + if (chunkSize > 0xFFFFFFFFUL) { + chunkSize = 0xFFFFFFFFUL; + } + + return (drwav_uint32)chunkSize; /* Safe cast due to the clamp above. */ +} + +static drwav_uint32 drwav__data_chunk_size_riff(drwav_uint64 dataChunkSize) +{ + if (dataChunkSize <= 0xFFFFFFFFUL) { + return (drwav_uint32)dataChunkSize; + } else { + return 0xFFFFFFFFUL; + } +} + +static drwav_uint64 drwav__riff_chunk_size_w64(drwav_uint64 dataChunkSize) +{ + drwav_uint64 dataSubchunkPaddingSize = drwav__chunk_padding_size_w64(dataChunkSize); + + return 80 + 24 + dataChunkSize + dataSubchunkPaddingSize; /* +24 because W64 includes the size of the GUID and size fields. */ +} + +static drwav_uint64 drwav__data_chunk_size_w64(drwav_uint64 dataChunkSize) +{ + return 24 + dataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ +} + +static drwav_uint64 drwav__riff_chunk_size_rf64(drwav_uint64 dataChunkSize) +{ + drwav_uint64 chunkSize = 4 + 36 + 24 + dataChunkSize + drwav__chunk_padding_size_riff(dataChunkSize); /* 4 = "WAVE". 36 = "ds64" chunk. 24 = "fmt " chunk. */ + if (chunkSize > 0xFFFFFFFFUL) { + chunkSize = 0xFFFFFFFFUL; + } + + return chunkSize; +} + +static drwav_uint64 drwav__data_chunk_size_rf64(drwav_uint64 dataChunkSize) +{ + return dataChunkSize; +} + + +static size_t drwav__write(drwav* pWav, const void* pData, size_t dataSize) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + /* Generic write. Assumes no byte reordering required. 
*/ + return pWav->onWrite(pWav->pUserData, pData, dataSize); +} + +static size_t drwav__write_u16ne_to_le(drwav* pWav, drwav_uint16 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap16(value); + } + + return drwav__write(pWav, &value, 2); +} + +static size_t drwav__write_u32ne_to_le(drwav* pWav, drwav_uint32 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap32(value); + } + + return drwav__write(pWav, &value, 4); +} + +static size_t drwav__write_u64ne_to_le(drwav* pWav, drwav_uint64 value) +{ + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->onWrite != NULL); + + if (!drwav__is_little_endian()) { + value = drwav__bswap64(value); + } + + return drwav__write(pWav, &value, 8); +} + + +static drwav_bool32 drwav_preinit_write(drwav* pWav, const drwav_data_format* pFormat, drwav_bool32 isSequential, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pWav == NULL || onWrite == NULL) { + return DRWAV_FALSE; + } + + if (!isSequential && onSeek == NULL) { + return DRWAV_FALSE; /* <-- onSeek is required when in non-sequential mode. */ + } + + /* Not currently supporting compressed formats. Will need to add support for the "fact" chunk before we enable this. */ + if (pFormat->format == DR_WAVE_FORMAT_EXTENSIBLE) { + return DRWAV_FALSE; + } + if (pFormat->format == DR_WAVE_FORMAT_ADPCM || pFormat->format == DR_WAVE_FORMAT_DVI_ADPCM) { + return DRWAV_FALSE; + } + + DRWAV_ZERO_MEMORY(pWav, sizeof(*pWav)); + pWav->onWrite = onWrite; + pWav->onSeek = onSeek; + pWav->pUserData = pUserData; + pWav->allocationCallbacks = drwav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); + + if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { + return DRWAV_FALSE; /* Invalid allocation callbacks. */ + } + + pWav->fmt.formatTag = (drwav_uint16)pFormat->format; + pWav->fmt.channels = (drwav_uint16)pFormat->channels; + pWav->fmt.sampleRate = pFormat->sampleRate; + pWav->fmt.avgBytesPerSec = (drwav_uint32)((pFormat->bitsPerSample * pFormat->sampleRate * pFormat->channels) / 8); + pWav->fmt.blockAlign = (drwav_uint16)((pFormat->channels * pFormat->bitsPerSample) / 8); + pWav->fmt.bitsPerSample = (drwav_uint16)pFormat->bitsPerSample; + pWav->fmt.extendedSize = 0; + pWav->isSequentialWrite = isSequential; + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init_write__internal(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount) +{ + /* The function assumes drwav_preinit_write() was called beforehand. */ + + size_t runningPos = 0; + drwav_uint64 initialDataChunkSize = 0; + drwav_uint64 chunkSizeFMT; + + /* + The initial values for the "RIFF" and "data" chunks depends on whether or not we are initializing in sequential mode or not. In + sequential mode we set this to its final values straight away since they can be calculated from the total sample count. In non- + sequential mode we initialize it all to zero and fill it out in drwav_uninit() using a backwards seek. + */ + if (pWav->isSequentialWrite) { + initialDataChunkSize = (totalSampleCount * pWav->fmt.bitsPerSample) / 8; + + /* + The RIFF container has a limit on the number of samples. drwav is not allowing this. 
There's no practical limits for Wave64 + so for the sake of simplicity I'm not doing any validation for that. + */ + if (pFormat->container == drwav_container_riff) { + if (initialDataChunkSize > (0xFFFFFFFFUL - 36)) { + return DRWAV_FALSE; /* Not enough room to store every sample. */ + } + } + } + + pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize; + + + /* "RIFF" chunk. */ + if (pFormat->container == drwav_container_riff) { + drwav_uint32 chunkSizeRIFF = 28 + (drwav_uint32)initialDataChunkSize; /* +28 = "WAVE" + [sizeof "fmt " chunk] */ + runningPos += drwav__write(pWav, "RIFF", 4); + runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeRIFF); + runningPos += drwav__write(pWav, "WAVE", 4); + } else if (pFormat->container == drwav_container_w64) { + drwav_uint64 chunkSizeRIFF = 80 + 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ + runningPos += drwav__write(pWav, drwavGUID_W64_RIFF, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeRIFF); + runningPos += drwav__write(pWav, drwavGUID_W64_WAVE, 16); + } else if (pFormat->container == drwav_container_rf64) { + runningPos += drwav__write(pWav, "RF64", 4); + runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always 0xFFFFFFFF for RF64. Set to a proper value in the "ds64" chunk. */ + runningPos += drwav__write(pWav, "WAVE", 4); + } + + + /* "ds64" chunk (RF64 only). */ + if (pFormat->container == drwav_container_rf64) { + drwav_uint32 initialds64ChunkSize = 28; /* 28 = [Size of RIFF (8 bytes)] + [Size of DATA (8 bytes)] + [Sample Count (8 bytes)] + [Table Length (4 bytes)]. Table length always set to 0. */ + drwav_uint64 initialRiffChunkSize = 8 + initialds64ChunkSize + initialDataChunkSize; /* +8 for the ds64 header. */ + + runningPos += drwav__write(pWav, "ds64", 4); + runningPos += drwav__write_u32ne_to_le(pWav, initialds64ChunkSize); /* Size of ds64. */ + runningPos += drwav__write_u64ne_to_le(pWav, initialRiffChunkSize); /* Size of RIFF. Set to true value at the end. */ + runningPos += drwav__write_u64ne_to_le(pWav, initialDataChunkSize); /* Size of DATA. Set to true value at the end. */ + runningPos += drwav__write_u64ne_to_le(pWav, totalSampleCount); /* Sample count. */ + runningPos += drwav__write_u32ne_to_le(pWav, 0); /* Table length. Always set to zero in our case since we're not doing any other chunks than "DATA". */ + } + + + /* "fmt " chunk. */ + if (pFormat->container == drwav_container_riff || pFormat->container == drwav_container_rf64) { + chunkSizeFMT = 16; + runningPos += drwav__write(pWav, "fmt ", 4); + runningPos += drwav__write_u32ne_to_le(pWav, (drwav_uint32)chunkSizeFMT); + } else if (pFormat->container == drwav_container_w64) { + chunkSizeFMT = 40; + runningPos += drwav__write(pWav, drwavGUID_W64_FMT, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeFMT); + } + + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.formatTag); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.channels); + runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.sampleRate); + runningPos += drwav__write_u32ne_to_le(pWav, pWav->fmt.avgBytesPerSec); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.blockAlign); + runningPos += drwav__write_u16ne_to_le(pWav, pWav->fmt.bitsPerSample); + + pWav->dataChunkDataPos = runningPos; + + /* "data" chunk. 
*/ + if (pFormat->container == drwav_container_riff) { + drwav_uint32 chunkSizeDATA = (drwav_uint32)initialDataChunkSize; + runningPos += drwav__write(pWav, "data", 4); + runningPos += drwav__write_u32ne_to_le(pWav, chunkSizeDATA); + } else if (pFormat->container == drwav_container_w64) { + drwav_uint64 chunkSizeDATA = 24 + initialDataChunkSize; /* +24 because W64 includes the size of the GUID and size fields. */ + runningPos += drwav__write(pWav, drwavGUID_W64_DATA, 16); + runningPos += drwav__write_u64ne_to_le(pWav, chunkSizeDATA); + } else if (pFormat->container == drwav_container_rf64) { + runningPos += drwav__write(pWav, "data", 4); + runningPos += drwav__write_u32ne_to_le(pWav, 0xFFFFFFFF); /* Always set to 0xFFFFFFFF for RF64. The true size of the data chunk is specified in the ds64 chunk. */ + } + + /* + The runningPos variable is incremented in the section above but is left unused which is causing some static analysis tools to detect it + as a dead store. I'm leaving this as-is for safety just in case I want to expand this function later to include other tags and want to + keep track of the running position for whatever reason. The line below should silence the static analysis tools. + */ + (void)runningPos; + + /* Set some properties for the client's convenience. */ + pWav->container = pFormat->container; + pWav->channels = (drwav_uint16)pFormat->channels; + pWav->sampleRate = pFormat->sampleRate; + pWav->bitsPerSample = (drwav_uint16)pFormat->bitsPerSample; + pWav->translatedFormatTag = (drwav_uint16)pFormat->format; + + return DRWAV_TRUE; +} + + +DRWAV_API drwav_bool32 drwav_init_write(drwav* pWav, const drwav_data_format* pFormat, drwav_write_proc onWrite, drwav_seek_proc onSeek, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit_write(pWav, pFormat, DRWAV_FALSE, onWrite, onSeek, pUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init_write__internal(pWav, pFormat, 0); /* DRWAV_FALSE = Not Sequential */ +} + +DRWAV_API drwav_bool32 drwav_init_write_sequential(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (!drwav_preinit_write(pWav, pFormat, DRWAV_TRUE, onWrite, NULL, pUserData, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + return drwav_init_write__internal(pWav, pFormat, totalSampleCount); /* DRWAV_TRUE = Sequential */ +} + +DRWAV_API drwav_bool32 drwav_init_write_sequential_pcm_frames(drwav* pWav, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, drwav_write_proc onWrite, void* pUserData, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_write_sequential(pWav, pFormat, totalPCMFrameCount*pFormat->channels, onWrite, pUserData, pAllocationCallbacks); +} + +DRWAV_API drwav_uint64 drwav_target_write_size_bytes(const drwav_data_format* pFormat, drwav_uint64 totalSampleCount) +{ + /* Casting totalSampleCount to drwav_int64 for VC6 compatibility. No issues in practice because nobody is going to exhaust the whole 63 bits. 
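A minimal sketch of the sequential write path set up by the initializers above, using the stdio convenience wrapper declared near the top of the header; "out.wav", totalFrames and pInterleaved are hypothetical:

    drwav wav;
    drwav_data_format format;
    format.container     = drwav_container_riff;
    format.format        = DR_WAVE_FORMAT_PCM;
    format.channels      = 2;
    format.sampleRate    = 44100;
    format.bitsPerSample = 16;
    if (drwav_init_file_write_sequential_pcm_frames(&wav, "out.wav", &format, totalFrames, NULL)) {
        drwav_write_pcm_frames(&wav, totalFrames, pInterleaved); /* interleaved drwav_int16 data */
        drwav_uninit(&wav); /* sequential mode wrote the final header up front; no backwards seek needed */
    }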
+*/
+    drwav_uint64 targetDataSizeBytes = (drwav_uint64)((drwav_int64)totalSampleCount * pFormat->channels * pFormat->bitsPerSample/8.0);
+    drwav_uint64 riffChunkSizeBytes;
+    drwav_uint64 fileSizeBytes = 0;
+
+    if (pFormat->container == drwav_container_riff) {
+        riffChunkSizeBytes = drwav__riff_chunk_size_riff(targetDataSizeBytes);
+        fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */
+    } else if (pFormat->container == drwav_container_w64) {
+        riffChunkSizeBytes = drwav__riff_chunk_size_w64(targetDataSizeBytes);
+        fileSizeBytes = riffChunkSizeBytes;
+    } else if (pFormat->container == drwav_container_rf64) {
+        riffChunkSizeBytes = drwav__riff_chunk_size_rf64(targetDataSizeBytes);
+        fileSizeBytes = (8 + riffChunkSizeBytes); /* +8 because WAV doesn't include the size of the ChunkID and ChunkSize fields. */
+    }
+
+    return fileSizeBytes;
+}
+
+
+#ifndef DR_WAV_NO_STDIO
+
+/* drwav_result_from_errno() is only used for fopen() and wfopen() so putting it inside DR_WAV_NO_STDIO for now. If something else needs this later we can move it out. */
+#include <errno.h>
+static drwav_result drwav_result_from_errno(int e)
+{
+    switch (e)
+    {
+        case 0: return DRWAV_SUCCESS;
+    #ifdef EPERM
+        case EPERM: return DRWAV_INVALID_OPERATION;
+    #endif
+    #ifdef ENOENT
+        case ENOENT: return DRWAV_DOES_NOT_EXIST;
+    #endif
+    #ifdef ESRCH
+        case ESRCH: return DRWAV_DOES_NOT_EXIST;
+    #endif
+    #ifdef EINTR
+        case EINTR: return DRWAV_INTERRUPT;
+    #endif
+    #ifdef EIO
+        case EIO: return DRWAV_IO_ERROR;
+    #endif
+    #ifdef ENXIO
+        case ENXIO: return DRWAV_DOES_NOT_EXIST;
+    #endif
+    #ifdef E2BIG
+        case E2BIG: return DRWAV_INVALID_ARGS;
+    #endif
+    #ifdef ENOEXEC
+        case ENOEXEC: return DRWAV_INVALID_FILE;
+    #endif
+    #ifdef EBADF
+        case EBADF: return DRWAV_INVALID_FILE;
+    #endif
+    #ifdef ECHILD
+        case ECHILD: return DRWAV_ERROR;
+    #endif
+    #ifdef EAGAIN
+        case EAGAIN: return DRWAV_UNAVAILABLE;
+    #endif
+    #ifdef ENOMEM
+        case ENOMEM: return DRWAV_OUT_OF_MEMORY;
+    #endif
+    #ifdef EACCES
+        case EACCES: return DRWAV_ACCESS_DENIED;
+    #endif
+    #ifdef EFAULT
+        case EFAULT: return DRWAV_BAD_ADDRESS;
+    #endif
+    #ifdef ENOTBLK
+        case ENOTBLK: return DRWAV_ERROR;
+    #endif
+    #ifdef EBUSY
+        case EBUSY: return DRWAV_BUSY;
+    #endif
+    #ifdef EEXIST
+        case EEXIST: return DRWAV_ALREADY_EXISTS;
+    #endif
+    #ifdef EXDEV
+        case EXDEV: return DRWAV_ERROR;
+    #endif
+    #ifdef ENODEV
+        case ENODEV: return DRWAV_DOES_NOT_EXIST;
+    #endif
+    #ifdef ENOTDIR
+        case ENOTDIR: return DRWAV_NOT_DIRECTORY;
+    #endif
+    #ifdef EISDIR
+        case EISDIR: return DRWAV_IS_DIRECTORY;
+    #endif
+    #ifdef EINVAL
+        case EINVAL: return DRWAV_INVALID_ARGS;
+    #endif
+    #ifdef ENFILE
+        case ENFILE: return DRWAV_TOO_MANY_OPEN_FILES;
+    #endif
+    #ifdef EMFILE
+        case EMFILE: return DRWAV_TOO_MANY_OPEN_FILES;
+    #endif
+    #ifdef ENOTTY
+        case ENOTTY: return DRWAV_INVALID_OPERATION;
+    #endif
+    #ifdef ETXTBSY
+        case ETXTBSY: return DRWAV_BUSY;
+    #endif
+    #ifdef EFBIG
+        case EFBIG: return DRWAV_TOO_BIG;
+    #endif
+    #ifdef ENOSPC
+        case ENOSPC: return DRWAV_NO_SPACE;
+    #endif
+    #ifdef ESPIPE
+        case ESPIPE: return DRWAV_BAD_SEEK;
+    #endif
+    #ifdef EROFS
+        case EROFS: return DRWAV_ACCESS_DENIED;
+    #endif
+    #ifdef EMLINK
+        case EMLINK: return DRWAV_TOO_MANY_LINKS;
+    #endif
+    #ifdef EPIPE
+        case EPIPE: return DRWAV_BAD_PIPE;
+    #endif
+    #ifdef EDOM
+        case EDOM: return DRWAV_OUT_OF_RANGE;
+    #endif
+    #ifdef ERANGE
+        case ERANGE: return DRWAV_OUT_OF_RANGE;
+    #endif
+    #ifdef EDEADLK
+        case EDEADLK: return DRWAV_DEADLOCK;
+    #endif
+    #ifdef ENAMETOOLONG
+ case ENAMETOOLONG: return DRWAV_PATH_TOO_LONG; + #endif + #ifdef ENOLCK + case ENOLCK: return DRWAV_ERROR; + #endif + #ifdef ENOSYS + case ENOSYS: return DRWAV_NOT_IMPLEMENTED; + #endif + #ifdef ENOTEMPTY + case ENOTEMPTY: return DRWAV_DIRECTORY_NOT_EMPTY; + #endif + #ifdef ELOOP + case ELOOP: return DRWAV_TOO_MANY_LINKS; + #endif + #ifdef ENOMSG + case ENOMSG: return DRWAV_NO_MESSAGE; + #endif + #ifdef EIDRM + case EIDRM: return DRWAV_ERROR; + #endif + #ifdef ECHRNG + case ECHRNG: return DRWAV_ERROR; + #endif + #ifdef EL2NSYNC + case EL2NSYNC: return DRWAV_ERROR; + #endif + #ifdef EL3HLT + case EL3HLT: return DRWAV_ERROR; + #endif + #ifdef EL3RST + case EL3RST: return DRWAV_ERROR; + #endif + #ifdef ELNRNG + case ELNRNG: return DRWAV_OUT_OF_RANGE; + #endif + #ifdef EUNATCH + case EUNATCH: return DRWAV_ERROR; + #endif + #ifdef ENOCSI + case ENOCSI: return DRWAV_ERROR; + #endif + #ifdef EL2HLT + case EL2HLT: return DRWAV_ERROR; + #endif + #ifdef EBADE + case EBADE: return DRWAV_ERROR; + #endif + #ifdef EBADR + case EBADR: return DRWAV_ERROR; + #endif + #ifdef EXFULL + case EXFULL: return DRWAV_ERROR; + #endif + #ifdef ENOANO + case ENOANO: return DRWAV_ERROR; + #endif + #ifdef EBADRQC + case EBADRQC: return DRWAV_ERROR; + #endif + #ifdef EBADSLT + case EBADSLT: return DRWAV_ERROR; + #endif + #ifdef EBFONT + case EBFONT: return DRWAV_INVALID_FILE; + #endif + #ifdef ENOSTR + case ENOSTR: return DRWAV_ERROR; + #endif + #ifdef ENODATA + case ENODATA: return DRWAV_NO_DATA_AVAILABLE; + #endif + #ifdef ETIME + case ETIME: return DRWAV_TIMEOUT; + #endif + #ifdef ENOSR + case ENOSR: return DRWAV_NO_DATA_AVAILABLE; + #endif + #ifdef ENONET + case ENONET: return DRWAV_NO_NETWORK; + #endif + #ifdef ENOPKG + case ENOPKG: return DRWAV_ERROR; + #endif + #ifdef EREMOTE + case EREMOTE: return DRWAV_ERROR; + #endif + #ifdef ENOLINK + case ENOLINK: return DRWAV_ERROR; + #endif + #ifdef EADV + case EADV: return DRWAV_ERROR; + #endif + #ifdef ESRMNT + case ESRMNT: return DRWAV_ERROR; + #endif + #ifdef ECOMM + case ECOMM: return DRWAV_ERROR; + #endif + #ifdef EPROTO + case EPROTO: return DRWAV_ERROR; + #endif + #ifdef EMULTIHOP + case EMULTIHOP: return DRWAV_ERROR; + #endif + #ifdef EDOTDOT + case EDOTDOT: return DRWAV_ERROR; + #endif + #ifdef EBADMSG + case EBADMSG: return DRWAV_BAD_MESSAGE; + #endif + #ifdef EOVERFLOW + case EOVERFLOW: return DRWAV_TOO_BIG; + #endif + #ifdef ENOTUNIQ + case ENOTUNIQ: return DRWAV_NOT_UNIQUE; + #endif + #ifdef EBADFD + case EBADFD: return DRWAV_ERROR; + #endif + #ifdef EREMCHG + case EREMCHG: return DRWAV_ERROR; + #endif + #ifdef ELIBACC + case ELIBACC: return DRWAV_ACCESS_DENIED; + #endif + #ifdef ELIBBAD + case ELIBBAD: return DRWAV_INVALID_FILE; + #endif + #ifdef ELIBSCN + case ELIBSCN: return DRWAV_INVALID_FILE; + #endif + #ifdef ELIBMAX + case ELIBMAX: return DRWAV_ERROR; + #endif + #ifdef ELIBEXEC + case ELIBEXEC: return DRWAV_ERROR; + #endif + #ifdef EILSEQ + case EILSEQ: return DRWAV_INVALID_DATA; + #endif + #ifdef ERESTART + case ERESTART: return DRWAV_ERROR; + #endif + #ifdef ESTRPIPE + case ESTRPIPE: return DRWAV_ERROR; + #endif + #ifdef EUSERS + case EUSERS: return DRWAV_ERROR; + #endif + #ifdef ENOTSOCK + case ENOTSOCK: return DRWAV_NOT_SOCKET; + #endif + #ifdef EDESTADDRREQ + case EDESTADDRREQ: return DRWAV_NO_ADDRESS; + #endif + #ifdef EMSGSIZE + case EMSGSIZE: return DRWAV_TOO_BIG; + #endif + #ifdef EPROTOTYPE + case EPROTOTYPE: return DRWAV_BAD_PROTOCOL; + #endif + #ifdef ENOPROTOOPT + case ENOPROTOOPT: return DRWAV_PROTOCOL_UNAVAILABLE; + #endif + #ifdef 
EPROTONOSUPPORT + case EPROTONOSUPPORT: return DRWAV_PROTOCOL_NOT_SUPPORTED; + #endif + #ifdef ESOCKTNOSUPPORT + case ESOCKTNOSUPPORT: return DRWAV_SOCKET_NOT_SUPPORTED; + #endif + #ifdef EOPNOTSUPP + case EOPNOTSUPP: return DRWAV_INVALID_OPERATION; + #endif + #ifdef EPFNOSUPPORT + case EPFNOSUPPORT: return DRWAV_PROTOCOL_FAMILY_NOT_SUPPORTED; + #endif + #ifdef EAFNOSUPPORT + case EAFNOSUPPORT: return DRWAV_ADDRESS_FAMILY_NOT_SUPPORTED; + #endif + #ifdef EADDRINUSE + case EADDRINUSE: return DRWAV_ALREADY_IN_USE; + #endif + #ifdef EADDRNOTAVAIL + case EADDRNOTAVAIL: return DRWAV_ERROR; + #endif + #ifdef ENETDOWN + case ENETDOWN: return DRWAV_NO_NETWORK; + #endif + #ifdef ENETUNREACH + case ENETUNREACH: return DRWAV_NO_NETWORK; + #endif + #ifdef ENETRESET + case ENETRESET: return DRWAV_NO_NETWORK; + #endif + #ifdef ECONNABORTED + case ECONNABORTED: return DRWAV_NO_NETWORK; + #endif + #ifdef ECONNRESET + case ECONNRESET: return DRWAV_CONNECTION_RESET; + #endif + #ifdef ENOBUFS + case ENOBUFS: return DRWAV_NO_SPACE; + #endif + #ifdef EISCONN + case EISCONN: return DRWAV_ALREADY_CONNECTED; + #endif + #ifdef ENOTCONN + case ENOTCONN: return DRWAV_NOT_CONNECTED; + #endif + #ifdef ESHUTDOWN + case ESHUTDOWN: return DRWAV_ERROR; + #endif + #ifdef ETOOMANYREFS + case ETOOMANYREFS: return DRWAV_ERROR; + #endif + #ifdef ETIMEDOUT + case ETIMEDOUT: return DRWAV_TIMEOUT; + #endif + #ifdef ECONNREFUSED + case ECONNREFUSED: return DRWAV_CONNECTION_REFUSED; + #endif + #ifdef EHOSTDOWN + case EHOSTDOWN: return DRWAV_NO_HOST; + #endif + #ifdef EHOSTUNREACH + case EHOSTUNREACH: return DRWAV_NO_HOST; + #endif + #ifdef EALREADY + case EALREADY: return DRWAV_IN_PROGRESS; + #endif + #ifdef EINPROGRESS + case EINPROGRESS: return DRWAV_IN_PROGRESS; + #endif + #ifdef ESTALE + case ESTALE: return DRWAV_INVALID_FILE; + #endif + #ifdef EUCLEAN + case EUCLEAN: return DRWAV_ERROR; + #endif + #ifdef ENOTNAM + case ENOTNAM: return DRWAV_ERROR; + #endif + #ifdef ENAVAIL + case ENAVAIL: return DRWAV_ERROR; + #endif + #ifdef EISNAM + case EISNAM: return DRWAV_ERROR; + #endif + #ifdef EREMOTEIO + case EREMOTEIO: return DRWAV_IO_ERROR; + #endif + #ifdef EDQUOT + case EDQUOT: return DRWAV_NO_SPACE; + #endif + #ifdef ENOMEDIUM + case ENOMEDIUM: return DRWAV_DOES_NOT_EXIST; + #endif + #ifdef EMEDIUMTYPE + case EMEDIUMTYPE: return DRWAV_ERROR; + #endif + #ifdef ECANCELED + case ECANCELED: return DRWAV_CANCELLED; + #endif + #ifdef ENOKEY + case ENOKEY: return DRWAV_ERROR; + #endif + #ifdef EKEYEXPIRED + case EKEYEXPIRED: return DRWAV_ERROR; + #endif + #ifdef EKEYREVOKED + case EKEYREVOKED: return DRWAV_ERROR; + #endif + #ifdef EKEYREJECTED + case EKEYREJECTED: return DRWAV_ERROR; + #endif + #ifdef EOWNERDEAD + case EOWNERDEAD: return DRWAV_ERROR; + #endif + #ifdef ENOTRECOVERABLE + case ENOTRECOVERABLE: return DRWAV_ERROR; + #endif + #ifdef ERFKILL + case ERFKILL: return DRWAV_ERROR; + #endif + #ifdef EHWPOISON + case EHWPOISON: return DRWAV_ERROR; + #endif + default: return DRWAV_ERROR; + } +} + +static drwav_result drwav_fopen(FILE** ppFile, const char* pFilePath, const char* pOpenMode) +{ +#if _MSC_VER && _MSC_VER >= 1400 + errno_t err; +#endif + + if (ppFile != NULL) { + *ppFile = NULL; /* Safety. 
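Clearing the output pointer up front means the caller can never observe a stale FILE* after a failure. As a rough usage sketch (the path is hypothetical, everything else is as declared in this file):

    FILE* pFile;
    drwav_result result = drwav_fopen(&pFile, "input.wav", "rb");
    if (result != DRWAV_SUCCESS) {
        pFile is guaranteed to be NULL here, and result carries the
        translated errno, e.g. DRWAV_DOES_NOT_EXIST for ENOENT.
    }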
*/ + } + + if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) { + return DRWAV_INVALID_ARGS; + } + +#if _MSC_VER && _MSC_VER >= 1400 + err = fopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return drwav_result_from_errno(err); + } +#else +#if defined(_WIN32) || defined(__APPLE__) + *ppFile = fopen(pFilePath, pOpenMode); +#else + #if defined(_FILE_OFFSET_BITS) && _FILE_OFFSET_BITS == 64 && defined(_LARGEFILE64_SOURCE) + *ppFile = fopen64(pFilePath, pOpenMode); + #else + *ppFile = fopen(pFilePath, pOpenMode); + #endif +#endif + if (*ppFile == NULL) { + drwav_result result = drwav_result_from_errno(errno); + if (result == DRWAV_SUCCESS) { + result = DRWAV_ERROR; /* Just a safety check to make sure we never ever return success when pFile == NULL. */ + } + + return result; + } +#endif + + return DRWAV_SUCCESS; +} + +/* +_wfopen() isn't always available in all compilation environments. + + * Windows only. + * MSVC seems to support it universally as far back as VC6 from what I can tell (haven't checked further back). + * MinGW-64 (both 32- and 64-bit) seems to support it. + * MinGW wraps it in !defined(__STRICT_ANSI__). + * OpenWatcom wraps it in !defined(_NO_EXT_KEYS). + +This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs() +fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support. +*/ +#if defined(_WIN32) + #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS)) + #define DRWAV_HAS_WFOPEN + #endif +#endif + +static drwav_result drwav_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_t* pOpenMode, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (ppFile != NULL) { + *ppFile = NULL; /* Safety. */ + } + + if (pFilePath == NULL || pOpenMode == NULL || ppFile == NULL) { + return DRWAV_INVALID_ARGS; + } + +#if defined(DRWAV_HAS_WFOPEN) + { + /* Use _wfopen() on Windows. */ + #if defined(_MSC_VER) && _MSC_VER >= 1400 + errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return drwav_result_from_errno(err); + } + #else + *ppFile = _wfopen(pFilePath, pOpenMode); + if (*ppFile == NULL) { + return drwav_result_from_errno(errno); + } + #endif + (void)pAllocationCallbacks; + } +#else + /* + Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can + think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for + maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. + */ + { + mbstate_t mbs; + size_t lenMB; + const wchar_t* pFilePathTemp = pFilePath; + char* pFilePathMB = NULL; + char pOpenModeMB[32] = {0}; + + /* Get the length first. */ + DRWAV_ZERO_OBJECT(&mbs); + lenMB = wcsrtombs(NULL, &pFilePathTemp, 0, &mbs); + if (lenMB == (size_t)-1) { + return drwav_result_from_errno(errno); + } + + pFilePathMB = (char*)drwav__malloc_from_callbacks(lenMB + 1, pAllocationCallbacks); + if (pFilePathMB == NULL) { + return DRWAV_OUT_OF_MEMORY; + } + + pFilePathTemp = pFilePath; + DRWAV_ZERO_OBJECT(&mbs); + wcsrtombs(pFilePathMB, &pFilePathTemp, lenMB + 1, &mbs); + + /* The open mode should always consist of ASCII characters so we should be able to do a trivial conversion. 
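For example, L"rb" narrows to "rb": every character permitted in an fopen() mode string is 7-bit ASCII, so the (char) cast below is lossless:

    L'r' (0x0072) -> (char)0x72 == 'r'
    L'b' (0x0062) -> (char)0x62 == 'b'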
*/ + { + size_t i = 0; + for (;;) { + if (pOpenMode[i] == 0) { + pOpenModeMB[i] = '\0'; + break; + } + + pOpenModeMB[i] = (char)pOpenMode[i]; + i += 1; + } + } + + *ppFile = fopen(pFilePathMB, pOpenModeMB); + + drwav__free_from_callbacks(pFilePathMB, pAllocationCallbacks); + } + + if (*ppFile == NULL) { + return DRWAV_ERROR; + } +#endif + + return DRWAV_SUCCESS; +} + + +static size_t drwav__on_read_stdio(void* pUserData, void* pBufferOut, size_t bytesToRead) +{ + return fread(pBufferOut, 1, bytesToRead, (FILE*)pUserData); +} + +static size_t drwav__on_write_stdio(void* pUserData, const void* pData, size_t bytesToWrite) +{ + return fwrite(pData, 1, bytesToWrite, (FILE*)pUserData); +} + +static drwav_bool32 drwav__on_seek_stdio(void* pUserData, int offset, drwav_seek_origin origin) +{ + return fseek((FILE*)pUserData, offset, (origin == drwav_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; +} + +DRWAV_API drwav_bool32 drwav_init_file(drwav* pWav, const char* filename, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_ex(pWav, filename, NULL, NULL, 0, pAllocationCallbacks); +} + + +static drwav_bool32 drwav_init_file__internal_FILE(drwav* pWav, FILE* pFile, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav_bool32 result; + + result = drwav_preinit(pWav, drwav__on_read_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + result = drwav_init__internal(pWav, onChunk, pChunkUserData, flags); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init_file_ex(drwav* pWav, const char* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_fopen(&pFile, filename, "rb") != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_w(drwav* pWav, const wchar_t* filename, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_ex_w(pWav, filename, NULL, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_ex_w(drwav* pWav, const wchar_t* filename, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_wfopen(&pFile, filename, L"rb", pAllocationCallbacks) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. 
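Concretely, drwav_init_file__internal_FILE() calls fclose() itself on every failure path, and drwav_uninit() closes the handle on the success path, so the caller never touches the FILE* again. A minimal sketch (hypothetical path):

    drwav wav;
    if (drwav_init_file_w(&wav, L"input.wav", NULL)) {
        ...read PCM frames here...
        drwav_uninit(&wav);    <- closes the FILE* opened above
    }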
*/ + return drwav_init_file__internal_FILE(pWav, pFile, onChunk, pChunkUserData, flags, pAllocationCallbacks); +} + + +static drwav_bool32 drwav_init_file_write__internal_FILE(drwav* pWav, FILE* pFile, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav_bool32 result; + + result = drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_stdio, drwav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + result = drwav_init_write__internal(pWav, pFormat, totalSampleCount); + if (result != DRWAV_TRUE) { + fclose(pFile); + return result; + } + + return DRWAV_TRUE; +} + +static drwav_bool32 drwav_init_file_write__internal(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_fopen(&pFile, filename, "wb") != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks); +} + +static drwav_bool32 drwav_init_file_write_w__internal(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + FILE* pFile; + if (drwav_wfopen(&pFile, filename, L"wb", pAllocationCallbacks) != DRWAV_SUCCESS) { + return DRWAV_FALSE; + } + + /* This takes ownership of the FILE* object. */ + return drwav_init_file_write__internal_FILE(pWav, pFile, pFormat, totalSampleCount, isSequential, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write(drwav* pWav, const char* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_pcm_frames(drwav* pWav, const char* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_file_write_sequential(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write_w__internal(pWav, filename, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_file_write_sequential_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_file_write_w__internal(pWav, filename, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 
drwav_init_file_write_sequential_pcm_frames_w(drwav* pWav, const wchar_t* filename, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_file_write_sequential_w(pWav, filename, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} +#endif /* DR_WAV_NO_STDIO */ + + +static size_t drwav__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead) +{ + drwav* pWav = (drwav*)pUserData; + size_t bytesRemaining; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->memoryStream.dataSize >= pWav->memoryStream.currentReadPos); + + bytesRemaining = pWav->memoryStream.dataSize - pWav->memoryStream.currentReadPos; + if (bytesToRead > bytesRemaining) { + bytesToRead = bytesRemaining; + } + + if (bytesToRead > 0) { + DRWAV_COPY_MEMORY(pBufferOut, pWav->memoryStream.data + pWav->memoryStream.currentReadPos, bytesToRead); + pWav->memoryStream.currentReadPos += bytesToRead; + } + + return bytesToRead; +} + +static drwav_bool32 drwav__on_seek_memory(void* pUserData, int offset, drwav_seek_origin origin) +{ + drwav* pWav = (drwav*)pUserData; + DRWAV_ASSERT(pWav != NULL); + + if (origin == drwav_seek_origin_current) { + if (offset > 0) { + if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) { + return DRWAV_FALSE; /* Trying to seek too far forward. */ + } + } else { + if (pWav->memoryStream.currentReadPos < (size_t)-offset) { + return DRWAV_FALSE; /* Trying to seek too far backwards. */ + } + } + + /* This will never underflow thanks to the clamps above. */ + pWav->memoryStream.currentReadPos += offset; + } else { + if ((drwav_uint32)offset <= pWav->memoryStream.dataSize) { + pWav->memoryStream.currentReadPos = offset; + } else { + return DRWAV_FALSE; /* Trying to seek too far forward. */ + } + } + + return DRWAV_TRUE; +} + +static size_t drwav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite) +{ + drwav* pWav = (drwav*)pUserData; + size_t bytesRemaining; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(pWav->memoryStreamWrite.dataCapacity >= pWav->memoryStreamWrite.currentWritePos); + + bytesRemaining = pWav->memoryStreamWrite.dataCapacity - pWav->memoryStreamWrite.currentWritePos; + if (bytesRemaining < bytesToWrite) { + /* Need to reallocate. */ + void* pNewData; + size_t newDataCapacity = (pWav->memoryStreamWrite.dataCapacity == 0) ? 256 : pWav->memoryStreamWrite.dataCapacity * 2; + + /* If doubling wasn't enough, just make it the minimum required size to write the data. 
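For example, with dataCapacity = 256 and currentWritePos = 200, a 1000 byte write first tries 256 * 2 = 512, but 512 - 200 = 312 is still short of 1000, so the capacity is bumped straight to 200 + 1000 = 1200.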
*/ + if ((newDataCapacity - pWav->memoryStreamWrite.currentWritePos) < bytesToWrite) { + newDataCapacity = pWav->memoryStreamWrite.currentWritePos + bytesToWrite; + } + + pNewData = drwav__realloc_from_callbacks(*pWav->memoryStreamWrite.ppData, newDataCapacity, pWav->memoryStreamWrite.dataCapacity, &pWav->allocationCallbacks); + if (pNewData == NULL) { + return 0; + } + + *pWav->memoryStreamWrite.ppData = pNewData; + pWav->memoryStreamWrite.dataCapacity = newDataCapacity; + } + + DRWAV_COPY_MEMORY(((drwav_uint8*)(*pWav->memoryStreamWrite.ppData)) + pWav->memoryStreamWrite.currentWritePos, pDataIn, bytesToWrite); + + pWav->memoryStreamWrite.currentWritePos += bytesToWrite; + if (pWav->memoryStreamWrite.dataSize < pWav->memoryStreamWrite.currentWritePos) { + pWav->memoryStreamWrite.dataSize = pWav->memoryStreamWrite.currentWritePos; + } + + *pWav->memoryStreamWrite.pDataSize = pWav->memoryStreamWrite.dataSize; + + return bytesToWrite; +} + +static drwav_bool32 drwav__on_seek_memory_write(void* pUserData, int offset, drwav_seek_origin origin) +{ + drwav* pWav = (drwav*)pUserData; + DRWAV_ASSERT(pWav != NULL); + + if (origin == drwav_seek_origin_current) { + if (offset > 0) { + if (pWav->memoryStreamWrite.currentWritePos + offset > pWav->memoryStreamWrite.dataSize) { + offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); /* Trying to seek too far forward. */ + } + } else { + if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) { + offset = -(int)pWav->memoryStreamWrite.currentWritePos; /* Trying to seek too far backwards. */ + } + } + + /* This will never underflow thanks to the clamps above. */ + pWav->memoryStreamWrite.currentWritePos += offset; + } else { + if ((drwav_uint32)offset <= pWav->memoryStreamWrite.dataSize) { + pWav->memoryStreamWrite.currentWritePos = offset; + } else { + pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; /* Trying to seek too far forward. */ + } + } + + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_init_memory(drwav* pWav, const void* data, size_t dataSize, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_ex(pWav, data, dataSize, NULL, NULL, 0, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_ex(drwav* pWav, const void* data, size_t dataSize, drwav_chunk_proc onChunk, void* pChunkUserData, drwav_uint32 flags, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (data == NULL || dataSize == 0) { + return DRWAV_FALSE; + } + + if (!drwav_preinit(pWav, drwav__on_read_memory, drwav__on_seek_memory, pWav, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + pWav->memoryStream.data = (const drwav_uint8*)data; + pWav->memoryStream.dataSize = dataSize; + pWav->memoryStream.currentReadPos = 0; + + return drwav_init__internal(pWav, onChunk, pChunkUserData, flags); +} + + +static drwav_bool32 drwav_init_memory_write__internal(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, drwav_bool32 isSequential, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (ppData == NULL || pDataSize == NULL) { + return DRWAV_FALSE; + } + + *ppData = NULL; /* Important because we're using realloc()! 
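drwav__realloc_from_callbacks() follows realloc() semantics, so the very first growth must start from a NULL pointer rather than whatever the caller happened to leave in *ppData. A hypothetical caller (format, pSamples and frameCount assumed):

    void*  pData    = NULL;
    size_t dataSize = 0;
    drwav  wav;
    if (drwav_init_memory_write(&wav, &pData, &dataSize, &format, NULL)) {
        drwav_write_pcm_frames(&wav, frameCount, pSamples);
        drwav_uninit(&wav);
        pData and dataSize now describe the finished in-memory WAV.
    }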
*/ + *pDataSize = 0; + + if (!drwav_preinit_write(pWav, pFormat, isSequential, drwav__on_write_memory, drwav__on_seek_memory_write, pWav, pAllocationCallbacks)) { + return DRWAV_FALSE; + } + + pWav->memoryStreamWrite.ppData = ppData; + pWav->memoryStreamWrite.pDataSize = pDataSize; + pWav->memoryStreamWrite.dataSize = 0; + pWav->memoryStreamWrite.dataCapacity = 0; + pWav->memoryStreamWrite.currentWritePos = 0; + + return drwav_init_write__internal(pWav, pFormat, totalSampleCount); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, 0, DRWAV_FALSE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalSampleCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + return drwav_init_memory_write__internal(pWav, ppData, pDataSize, pFormat, totalSampleCount, DRWAV_TRUE, pAllocationCallbacks); +} + +DRWAV_API drwav_bool32 drwav_init_memory_write_sequential_pcm_frames(drwav* pWav, void** ppData, size_t* pDataSize, const drwav_data_format* pFormat, drwav_uint64 totalPCMFrameCount, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pFormat == NULL) { + return DRWAV_FALSE; + } + + return drwav_init_memory_write_sequential(pWav, ppData, pDataSize, pFormat, totalPCMFrameCount*pFormat->channels, pAllocationCallbacks); +} + + + +DRWAV_API drwav_result drwav_uninit(drwav* pWav) +{ + drwav_result result = DRWAV_SUCCESS; + + if (pWav == NULL) { + return DRWAV_INVALID_ARGS; + } + + /* + If the drwav object was opened in write mode we'll need to finalize a few things: + - Make sure the "data" chunk is aligned to 16-bits for RIFF containers, or 64 bits for W64 containers. + - Set the size of the "data" chunk. + */ + if (pWav->onWrite != NULL) { + drwav_uint32 paddingSize = 0; + + /* Padding. Do not adjust pWav->dataChunkDataSize - this should not include the padding. */ + if (pWav->container == drwav_container_riff || pWav->container == drwav_container_rf64) { + paddingSize = drwav__chunk_padding_size_riff(pWav->dataChunkDataSize); + } else { + paddingSize = drwav__chunk_padding_size_w64(pWav->dataChunkDataSize); + } + + if (paddingSize > 0) { + drwav_uint64 paddingData = 0; + drwav__write(pWav, &paddingData, paddingSize); /* Byte order does not matter for this. */ + } + + /* + Chunk sizes. When using sequential mode, these will have been filled in at initialization time. We only need + to do this when using non-sequential mode. + */ + if (pWav->onSeek && !pWav->isSequentialWrite) { + if (pWav->container == drwav_container_riff) { + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, 4, drwav_seek_origin_start)) { + drwav_uint32 riffChunkSize = drwav__riff_chunk_size_riff(pWav->dataChunkDataSize); + drwav__write_u32ne_to_le(pWav, riffChunkSize); + } + + /* the "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 4, drwav_seek_origin_start)) { + drwav_uint32 dataChunkSize = drwav__data_chunk_size_riff(pWav->dataChunkDataSize); + drwav__write_u32ne_to_le(pWav, dataChunkSize); + } + } else if (pWav->container == drwav_container_w64) { + /* The "RIFF" chunk size. 
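W64 identifies chunks with 16 byte GUIDs instead of 4 byte FourCC tags, and its 64-bit chunk sizes include the 24 byte chunk header itself (which is why no +8 was added for W64 when computing the target file size earlier). The size field therefore sits at byte offset 16, versus offset 4 in the RIFF branch above:

    offset 0  : GUID form of "riff"  (16 bytes)
    offset 16 : file size            (8 bytes, little endian)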
*/ + if (pWav->onSeek(pWav->pUserData, 16, drwav_seek_origin_start)) { + drwav_uint64 riffChunkSize = drwav__riff_chunk_size_w64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, riffChunkSize); + } + + /* The "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos + 16, drwav_seek_origin_start)) { + drwav_uint64 dataChunkSize = drwav__data_chunk_size_w64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, dataChunkSize); + } + } else if (pWav->container == drwav_container_rf64) { + /* We only need to update the ds64 chunk. The "RIFF" and "data" chunks always have their sizes set to 0xFFFFFFFF for RF64. */ + int ds64BodyPos = 12 + 8; + + /* The "RIFF" chunk size. */ + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, drwav_seek_origin_start)) { + drwav_uint64 riffChunkSize = drwav__riff_chunk_size_rf64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, riffChunkSize); + } + + /* The "data" chunk size. */ + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, drwav_seek_origin_start)) { + drwav_uint64 dataChunkSize = drwav__data_chunk_size_rf64(pWav->dataChunkDataSize); + drwav__write_u64ne_to_le(pWav, dataChunkSize); + } + } + } + + /* Validation for sequential mode. */ + if (pWav->isSequentialWrite) { + if (pWav->dataChunkDataSize != pWav->dataChunkDataSizeTargetWrite) { + result = DRWAV_INVALID_FILE; + } + } + } + +#ifndef DR_WAV_NO_STDIO + /* + If we opened the file with drwav_open_file() we will want to close the file handle. We can know whether or not drwav_open_file() + was used by looking at the onRead and onSeek callbacks. + */ + if (pWav->onRead == drwav__on_read_stdio || pWav->onWrite == drwav__on_write_stdio) { + fclose((FILE*)pWav->pUserData); + } +#endif + + return result; +} + + + +DRWAV_API size_t drwav_read_raw(drwav* pWav, size_t bytesToRead, void* pBufferOut) +{ + size_t bytesRead; + + if (pWav == NULL || bytesToRead == 0) { + return 0; + } + + if (bytesToRead > pWav->bytesRemaining) { + bytesToRead = (size_t)pWav->bytesRemaining; + } + + if (pBufferOut != NULL) { + bytesRead = pWav->onRead(pWav->pUserData, pBufferOut, bytesToRead); + } else { + /* We need to seek. If we fail, we need to read-and-discard to make sure we get a good byte count. */ + bytesRead = 0; + while (bytesRead < bytesToRead) { + size_t bytesToSeek = (bytesToRead - bytesRead); + if (bytesToSeek > 0x7FFFFFFF) { + bytesToSeek = 0x7FFFFFFF; + } + + if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, drwav_seek_origin_current) == DRWAV_FALSE) { + break; + } + + bytesRead += bytesToSeek; + } + + /* When we get here we may need to read-and-discard some data. */ + while (bytesRead < bytesToRead) { + drwav_uint8 buffer[4096]; + size_t bytesSeeked; + size_t bytesToSeek = (bytesToRead - bytesRead); + if (bytesToSeek > sizeof(buffer)) { + bytesToSeek = sizeof(buffer); + } + + bytesSeeked = pWav->onRead(pWav->pUserData, buffer, bytesToSeek); + bytesRead += bytesSeeked; + + if (bytesSeeked < bytesToSeek) { + break; /* Reached the end. */ + } + } + } + + pWav->bytesRemaining -= bytesRead; + return bytesRead; +} + + + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_le(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + drwav_uint32 bytesPerFrame; + drwav_uint64 bytesToRead; /* Intentionally uint64 instead of size_t so we can do a check that we're not reading too much on 32-bit builds. */ + + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + /* Cannot use this function for compressed formats. 
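MSADPCM and IMA ADPCM do not have a fixed number of bytes per PCM frame, so the bytesPerFrame arithmetic below is meaningless for them. Compressed streams go through the s16 decoders instead, roughly dispatched as:

    if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) {
        framesRead = drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, pOut);
    } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) {
        framesRead = drwav_read_pcm_frames_s16__ima(pWav, framesToRead, pOut);
    }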
*/ + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + return 0; + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + bytesToRead = framesToRead * bytesPerFrame; + if (bytesToRead > DRWAV_SIZE_MAX) { + bytesToRead = (DRWAV_SIZE_MAX / bytesPerFrame) * bytesPerFrame; /* Round the number of bytes to read to a clean frame boundary. */ + } + + /* + Doing an explicit check here just to make it clear that we don't want to be attempt to read anything if there's no bytes to read. There + *could* be a time where it evaluates to 0 due to overflowing. + */ + if (bytesToRead == 0) { + return 0; + } + + return drwav_read_raw(pWav, (size_t)bytesToRead, pBufferOut) / bytesPerFrame; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_be(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut); + + if (pBufferOut != NULL) { + drwav__bswap_samples(pBufferOut, framesRead*pWav->channels, drwav_get_bytes_per_pcm_frame(pWav)/pWav->channels, pWav->translatedFormatTag); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames(drwav* pWav, drwav_uint64 framesToRead, void* pBufferOut) +{ + if (drwav__is_little_endian()) { + return drwav_read_pcm_frames_le(pWav, framesToRead, pBufferOut); + } else { + return drwav_read_pcm_frames_be(pWav, framesToRead, pBufferOut); + } +} + + + +DRWAV_API drwav_bool32 drwav_seek_to_first_pcm_frame(drwav* pWav) +{ + if (pWav->onWrite != NULL) { + return DRWAV_FALSE; /* No seeking in write mode. */ + } + + if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, drwav_seek_origin_start)) { + return DRWAV_FALSE; + } + + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + pWav->compressed.iCurrentPCMFrame = 0; + + /* Cached data needs to be cleared for compressed formats. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + DRWAV_ZERO_OBJECT(&pWav->msadpcm); + } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + DRWAV_ZERO_OBJECT(&pWav->ima); + } else { + DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */ + } + } + + pWav->bytesRemaining = pWav->dataChunkDataSize; + return DRWAV_TRUE; +} + +DRWAV_API drwav_bool32 drwav_seek_to_pcm_frame(drwav* pWav, drwav_uint64 targetFrameIndex) +{ + /* Seeking should be compatible with wave files > 2GB. */ + + if (pWav == NULL || pWav->onSeek == NULL) { + return DRWAV_FALSE; + } + + /* No seeking in write mode. */ + if (pWav->onWrite != NULL) { + return DRWAV_FALSE; + } + + /* If there are no samples, just return DRWAV_TRUE without doing anything. */ + if (pWav->totalPCMFrameCount == 0) { + return DRWAV_TRUE; + } + + /* Make sure the sample is clamped. */ + if (targetFrameIndex >= pWav->totalPCMFrameCount) { + targetFrameIndex = pWav->totalPCMFrameCount - 1; + } + + /* + For compressed formats we just use a slow generic seek. If we are seeking forward we just seek forward. If we are going backwards we need + to seek back to the start. + */ + if (drwav__is_compressed_format_tag(pWav->translatedFormatTag)) { + /* TODO: This can be optimized. */ + + /* + If we're seeking forward it's simple - just keep reading samples until we hit the sample we're requesting. 
If we're seeking backwards, + we first need to seek back to the start and then just do the same thing as a forward seek. + */ + if (targetFrameIndex < pWav->compressed.iCurrentPCMFrame) { + if (!drwav_seek_to_first_pcm_frame(pWav)) { + return DRWAV_FALSE; + } + } + + if (targetFrameIndex > pWav->compressed.iCurrentPCMFrame) { + drwav_uint64 offsetInFrames = targetFrameIndex - pWav->compressed.iCurrentPCMFrame; + + drwav_int16 devnull[2048]; + while (offsetInFrames > 0) { + drwav_uint64 framesRead = 0; + drwav_uint64 framesToRead = offsetInFrames; + if (framesToRead > drwav_countof(devnull)/pWav->channels) { + framesToRead = drwav_countof(devnull)/pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + framesRead = drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, devnull); + } else if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + framesRead = drwav_read_pcm_frames_s16__ima(pWav, framesToRead, devnull); + } else { + DRWAV_ASSERT(DRWAV_FALSE); /* If this assertion is triggered it means I've implemented a new compressed format but forgot to add a branch for it here. */ + } + + if (framesRead != framesToRead) { + return DRWAV_FALSE; + } + + offsetInFrames -= framesRead; + } + } + } else { + drwav_uint64 totalSizeInBytes; + drwav_uint64 currentBytePos; + drwav_uint64 targetBytePos; + drwav_uint64 offset; + + totalSizeInBytes = pWav->totalPCMFrameCount * drwav_get_bytes_per_pcm_frame(pWav); + DRWAV_ASSERT(totalSizeInBytes >= pWav->bytesRemaining); + + currentBytePos = totalSizeInBytes - pWav->bytesRemaining; + targetBytePos = targetFrameIndex * drwav_get_bytes_per_pcm_frame(pWav); + + if (currentBytePos < targetBytePos) { + /* Offset forwards. */ + offset = (targetBytePos - currentBytePos); + } else { + /* Offset backwards. */ + if (!drwav_seek_to_first_pcm_frame(pWav)) { + return DRWAV_FALSE; + } + offset = targetBytePos; + } + + while (offset > 0) { + int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset); + if (!pWav->onSeek(pWav->pUserData, offset32, drwav_seek_origin_current)) { + return DRWAV_FALSE; + } + + pWav->bytesRemaining -= offset32; + offset -= offset32; + } + } + + return DRWAV_TRUE; +} + + +DRWAV_API size_t drwav_write_raw(drwav* pWav, size_t bytesToWrite, const void* pData) +{ + size_t bytesWritten; + + if (pWav == NULL || bytesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesWritten = pWav->onWrite(pWav->pUserData, pData, bytesToWrite); + pWav->dataChunkDataSize += bytesWritten; + + return bytesWritten; +} + + +DRWAV_API drwav_uint64 drwav_write_pcm_frames_le(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + drwav_uint64 bytesToWrite; + drwav_uint64 bytesWritten; + const drwav_uint8* pRunningData; + + if (pWav == NULL || framesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8); + if (bytesToWrite > DRWAV_SIZE_MAX) { + return 0; + } + + bytesWritten = 0; + pRunningData = (const drwav_uint8*)pData; + + while (bytesToWrite > 0) { + size_t bytesJustWritten; + drwav_uint64 bytesToWriteThisIteration; + + bytesToWriteThisIteration = bytesToWrite; + DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. 
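Note that drwav_write_raw() may perform a short write if the backend runs out of space; the loop below then exits early and the byte count is converted back to whole frames on return. For 16 bit stereo, for example, 4096 bytes written comes back as (4096 * 8) / 16 / 2 = 1024 frames.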
*/ + + bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, pRunningData); + if (bytesJustWritten == 0) { + break; + } + + bytesToWrite -= bytesJustWritten; + bytesWritten += bytesJustWritten; + pRunningData += bytesJustWritten; + } + + return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels; +} + +DRWAV_API drwav_uint64 drwav_write_pcm_frames_be(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + drwav_uint64 bytesToWrite; + drwav_uint64 bytesWritten; + drwav_uint32 bytesPerSample; + const drwav_uint8* pRunningData; + + if (pWav == NULL || framesToWrite == 0 || pData == NULL) { + return 0; + } + + bytesToWrite = ((framesToWrite * pWav->channels * pWav->bitsPerSample) / 8); + if (bytesToWrite > DRWAV_SIZE_MAX) { + return 0; + } + + bytesWritten = 0; + pRunningData = (const drwav_uint8*)pData; + + bytesPerSample = drwav_get_bytes_per_pcm_frame(pWav) / pWav->channels; + + while (bytesToWrite > 0) { + drwav_uint8 temp[4096]; + drwav_uint32 sampleCount; + size_t bytesJustWritten; + drwav_uint64 bytesToWriteThisIteration; + + bytesToWriteThisIteration = bytesToWrite; + DRWAV_ASSERT(bytesToWriteThisIteration <= DRWAV_SIZE_MAX); /* <-- This is checked above. */ + + /* + WAV files are always little-endian. We need to byte swap on big-endian architectures. Since our input buffer is read-only we need + to use an intermediary buffer for the conversion. + */ + sampleCount = sizeof(temp)/bytesPerSample; + + if (bytesToWriteThisIteration > ((drwav_uint64)sampleCount)*bytesPerSample) { + bytesToWriteThisIteration = ((drwav_uint64)sampleCount)*bytesPerSample; + } + + DRWAV_COPY_MEMORY(temp, pRunningData, (size_t)bytesToWriteThisIteration); + drwav__bswap_samples(temp, sampleCount, bytesPerSample, pWav->translatedFormatTag); + + bytesJustWritten = drwav_write_raw(pWav, (size_t)bytesToWriteThisIteration, temp); + if (bytesJustWritten == 0) { + break; + } + + bytesToWrite -= bytesJustWritten; + bytesWritten += bytesJustWritten; + pRunningData += bytesJustWritten; + } + + return (bytesWritten * 8) / pWav->bitsPerSample / pWav->channels; +} + +DRWAV_API drwav_uint64 drwav_write_pcm_frames(drwav* pWav, drwav_uint64 framesToWrite, const void* pData) +{ + if (drwav__is_little_endian()) { + return drwav_write_pcm_frames_le(pWav, framesToWrite, pData); + } else { + return drwav_write_pcm_frames_be(pWav, framesToWrite, pData); + } +} + + +static drwav_uint64 drwav_read_pcm_frames_s16__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead = 0; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(framesToRead > 0); + + /* TODO: Lots of room for optimization here. */ + + while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + /* If there are no cached frames we need to load a new block. */ + if (pWav->msadpcm.cachedFrameCount == 0 && pWav->msadpcm.bytesRemainingInBlock == 0) { + if (pWav->channels == 1) { + /* Mono. 
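A mono MSADPCM block opens with the 7 byte header parsed below:

    byte  0   : predictor index into coeff1Table/coeff2Table
    bytes 1-2 : initial delta (int16, little endian)
    bytes 3-4 : second seed sample (int16)
    bytes 5-6 : first seed sample  (int16)

The two seed samples are queued (oldest first) as the first two decoded frames before any nibbles are read.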
*/ + drwav_uint8 header[7]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + pWav->msadpcm.predictor[0] = header[0]; + pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 1); + pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 3); + pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 5); + pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0]; + pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.cachedFrameCount = 2; + } else { + /* Stereo. */ + drwav_uint8 header[14]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + pWav->msadpcm.predictor[0] = header[0]; + pWav->msadpcm.predictor[1] = header[1]; + pWav->msadpcm.delta[0] = drwav__bytes_to_s16(header + 2); + pWav->msadpcm.delta[1] = drwav__bytes_to_s16(header + 4); + pWav->msadpcm.prevFrames[0][1] = (drwav_int32)drwav__bytes_to_s16(header + 6); + pWav->msadpcm.prevFrames[1][1] = (drwav_int32)drwav__bytes_to_s16(header + 8); + pWav->msadpcm.prevFrames[0][0] = (drwav_int32)drwav__bytes_to_s16(header + 10); + pWav->msadpcm.prevFrames[1][0] = (drwav_int32)drwav__bytes_to_s16(header + 12); + + pWav->msadpcm.cachedFrames[0] = pWav->msadpcm.prevFrames[0][0]; + pWav->msadpcm.cachedFrames[1] = pWav->msadpcm.prevFrames[1][0]; + pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1]; + pWav->msadpcm.cachedFrameCount = 2; + } + } + + /* Output anything that's cached. */ + while (framesToRead > 0 && pWav->msadpcm.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + if (pBufferOut != NULL) { + drwav_uint32 iSample = 0; + for (iSample = 0; iSample < pWav->channels; iSample += 1) { + pBufferOut[iSample] = (drwav_int16)pWav->msadpcm.cachedFrames[(drwav_countof(pWav->msadpcm.cachedFrames) - (pWav->msadpcm.cachedFrameCount*pWav->channels)) + iSample]; + } + + pBufferOut += pWav->channels; + } + + framesToRead -= 1; + totalFramesRead += 1; + pWav->compressed.iCurrentPCMFrame += 1; + pWav->msadpcm.cachedFrameCount -= 1; + } + + if (framesToRead == 0) { + return totalFramesRead; + } + + + /* + If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next + loop iteration which will trigger the loading of a new block. + */ + if (pWav->msadpcm.cachedFrameCount == 0) { + if (pWav->msadpcm.bytesRemainingInBlock == 0) { + continue; + } else { + static drwav_int32 adaptationTable[] = { + 230, 230, 230, 230, 307, 409, 512, 614, + 768, 614, 512, 409, 307, 230, 230, 230 + }; + static drwav_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; + static drwav_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; + + drwav_uint8 nibbles; + drwav_int32 nibble0; + drwav_int32 nibble1; + + if (pWav->onRead(pWav->pUserData, &nibbles, 1) != 1) { + return totalFramesRead; + } + pWav->msadpcm.bytesRemainingInBlock -= 1; + + /* TODO: Optimize away these if statements. */ + nibble0 = ((nibbles & 0xF0) >> 4); if ((nibbles & 0x80)) { nibble0 |= 0xFFFFFFF0UL; } + nibble1 = ((nibbles & 0x0F) >> 0); if ((nibbles & 0x08)) { nibble1 |= 0xFFFFFFF0UL; } + + if (pWav->channels == 1) { + /* Mono. 
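Each 4-bit nibble expands with the standard MSADPCM recurrence, exactly as coded below:

    predicted = (prev1 * coeff1 + prev2 * coeff2) >> 8
    sample    = clamp(predicted + signedNibble * delta, -32768, 32767)
    delta     = max(16, (adaptationTable[nibble] * delta) >> 8)

The high nibble is decoded first, then the low nibble, giving two mono frames per byte.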
*/ + drwav_int32 newSample0; + drwav_int32 newSample1; + + newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample0 += nibble0 * pWav->msadpcm.delta[0]; + newSample0 = drwav_clamp(newSample0, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample0; + + + newSample1 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample1 += nibble1 * pWav->msadpcm.delta[0]; + newSample1 = drwav_clamp(newSample1, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample1; + + + pWav->msadpcm.cachedFrames[2] = newSample0; + pWav->msadpcm.cachedFrames[3] = newSample1; + pWav->msadpcm.cachedFrameCount = 2; + } else { + /* Stereo. */ + drwav_int32 newSample0; + drwav_int32 newSample1; + + /* Left. */ + newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; + newSample0 += nibble0 * pWav->msadpcm.delta[0]; + newSample0 = drwav_clamp(newSample0, -32768, 32767); + + pWav->msadpcm.delta[0] = (adaptationTable[((nibbles & 0xF0) >> 4)] * pWav->msadpcm.delta[0]) >> 8; + if (pWav->msadpcm.delta[0] < 16) { + pWav->msadpcm.delta[0] = 16; + } + + pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; + pWav->msadpcm.prevFrames[0][1] = newSample0; + + + /* Right. 
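The high nibble above fed the left channel through predictor/delta slot 0; the low nibble below feeds the right channel through its own state in slot 1:

    nibble0 (high) -> left  channel, state slot 0
    nibble1 (low)  -> right channel, state slot 1

One byte therefore decodes exactly one stereo frame, which is why cachedFrameCount is set to 1 at the end of this branch rather than the 2 used in the mono branch.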
*/ + newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; + newSample1 += nibble1 * pWav->msadpcm.delta[1]; + newSample1 = drwav_clamp(newSample1, -32768, 32767); + + pWav->msadpcm.delta[1] = (adaptationTable[((nibbles & 0x0F) >> 0)] * pWav->msadpcm.delta[1]) >> 8; + if (pWav->msadpcm.delta[1] < 16) { + pWav->msadpcm.delta[1] = 16; + } + + pWav->msadpcm.prevFrames[1][0] = pWav->msadpcm.prevFrames[1][1]; + pWav->msadpcm.prevFrames[1][1] = newSample1; + + pWav->msadpcm.cachedFrames[2] = newSample0; + pWav->msadpcm.cachedFrames[3] = newSample1; + pWav->msadpcm.cachedFrameCount = 1; + } + } + } + } + + return totalFramesRead; +} + + +static drwav_uint64 drwav_read_pcm_frames_s16__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead = 0; + drwav_uint32 iChannel; + + static drwav_int32 indexTable[16] = { + -1, -1, -1, -1, 2, 4, 6, 8, + -1, -1, -1, -1, 2, 4, 6, 8 + }; + + static drwav_int32 stepTable[89] = { + 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, + 19, 21, 23, 25, 28, 31, 34, 37, 41, 45, + 50, 55, 60, 66, 73, 80, 88, 97, 107, 118, + 130, 143, 157, 173, 190, 209, 230, 253, 279, 307, + 337, 371, 408, 449, 494, 544, 598, 658, 724, 796, + 876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066, + 2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358, + 5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899, + 15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767 + }; + + DRWAV_ASSERT(pWav != NULL); + DRWAV_ASSERT(framesToRead > 0); + + /* TODO: Lots of room for optimization here. */ + + while (framesToRead > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + /* If there are no cached samples we need to load a new block. */ + if (pWav->ima.cachedFrameCount == 0 && pWav->ima.bytesRemainingInBlock == 0) { + if (pWav->channels == 1) { + /* Mono. */ + drwav_uint8 header[4]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + if (header[2] >= drwav_countof(stepTable)) { + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current); + pWav->ima.bytesRemainingInBlock = 0; + return totalFramesRead; /* Invalid data. */ + } + + pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0); + pWav->ima.stepIndex[0] = header[2]; + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[0]; + pWav->ima.cachedFrameCount = 1; + } else { + /* Stereo. */ + drwav_uint8 header[8]; + if (pWav->onRead(pWav->pUserData, header, sizeof(header)) != sizeof(header)) { + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); + + if (header[2] >= drwav_countof(stepTable) || header[6] >= drwav_countof(stepTable)) { + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, drwav_seek_origin_current); + pWav->ima.bytesRemainingInBlock = 0; + return totalFramesRead; /* Invalid data. 
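A step index of 89 or more would read past the end of the 89 entry stepTable, so the remainder of the block is skipped and the read aborted. The 8 byte stereo IMA block header being validated here is laid out as:

    bytes 0-1 : left predictor (int16)     byte 2 : left step index  (must be < 89)
    bytes 4-5 : right predictor (int16)    byte 6 : right step index (must be < 89)

with bytes 3 and 7 reserved.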
*/ + } + + pWav->ima.predictor[0] = drwav__bytes_to_s16(header + 0); + pWav->ima.stepIndex[0] = header[2]; + pWav->ima.predictor[1] = drwav__bytes_to_s16(header + 4); + pWav->ima.stepIndex[1] = header[6]; + + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 2] = pWav->ima.predictor[0]; + pWav->ima.cachedFrames[drwav_countof(pWav->ima.cachedFrames) - 1] = pWav->ima.predictor[1]; + pWav->ima.cachedFrameCount = 1; + } + } + + /* Output anything that's cached. */ + while (framesToRead > 0 && pWav->ima.cachedFrameCount > 0 && pWav->compressed.iCurrentPCMFrame < pWav->totalPCMFrameCount) { + if (pBufferOut != NULL) { + drwav_uint32 iSample; + for (iSample = 0; iSample < pWav->channels; iSample += 1) { + pBufferOut[iSample] = (drwav_int16)pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + iSample]; + } + pBufferOut += pWav->channels; + } + + framesToRead -= 1; + totalFramesRead += 1; + pWav->compressed.iCurrentPCMFrame += 1; + pWav->ima.cachedFrameCount -= 1; + } + + if (framesToRead == 0) { + return totalFramesRead; + } + + /* + If there's nothing left in the cache, just go ahead and load more. If there's nothing left to load in the current block we just continue to the next + loop iteration which will trigger the loading of a new block. + */ + if (pWav->ima.cachedFrameCount == 0) { + if (pWav->ima.bytesRemainingInBlock == 0) { + continue; + } else { + /* + From what I can tell with stereo streams, it looks like every 4 bytes (8 samples) is for one channel. So it goes 4 bytes for the + left channel, 4 bytes for the right channel. + */ + pWav->ima.cachedFrameCount = 8; + for (iChannel = 0; iChannel < pWav->channels; ++iChannel) { + drwav_uint32 iByte; + drwav_uint8 nibbles[4]; + if (pWav->onRead(pWav->pUserData, &nibbles, 4) != 4) { + pWav->ima.cachedFrameCount = 0; + return totalFramesRead; + } + pWav->ima.bytesRemainingInBlock -= 4; + + for (iByte = 0; iByte < 4; ++iByte) { + drwav_uint8 nibble0 = ((nibbles[iByte] & 0x0F) >> 0); + drwav_uint8 nibble1 = ((nibbles[iByte] & 0xF0) >> 4); + + drwav_int32 step = stepTable[pWav->ima.stepIndex[iChannel]]; + drwav_int32 predictor = pWav->ima.predictor[iChannel]; + + drwav_int32 diff = step >> 3; + if (nibble0 & 1) diff += step >> 2; + if (nibble0 & 2) diff += step >> 1; + if (nibble0 & 4) diff += step; + if (nibble0 & 8) diff = -diff; + + predictor = drwav_clamp(predictor + diff, -32768, 32767); + pWav->ima.predictor[iChannel] = predictor; + pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble0], 0, (drwav_int32)drwav_countof(stepTable)-1); + pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+0)*pWav->channels + iChannel] = predictor; + + + step = stepTable[pWav->ima.stepIndex[iChannel]]; + predictor = pWav->ima.predictor[iChannel]; + + diff = step >> 3; + if (nibble1 & 1) diff += step >> 2; + if (nibble1 & 2) diff += step >> 1; + if (nibble1 & 4) diff += step; + if (nibble1 & 8) diff = -diff; + + predictor = drwav_clamp(predictor + diff, -32768, 32767); + pWav->ima.predictor[iChannel] = predictor; + pWav->ima.stepIndex[iChannel] = drwav_clamp(pWav->ima.stepIndex[iChannel] + indexTable[nibble1], 0, (drwav_int32)drwav_countof(stepTable)-1); + pWav->ima.cachedFrames[(drwav_countof(pWav->ima.cachedFrames) - (pWav->ima.cachedFrameCount*pWav->channels)) + (iByte*2+1)*pWav->channels + iChannel] = predictor; + } + } + } + } + } + + return totalFramesRead; +} + + +#ifndef 
DR_WAV_NO_CONVERSION_API +static unsigned short g_drwavAlawTable[256] = { + 0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580, + 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0, + 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600, + 0xD500, 0xD700, 0xD100, 0xD300, 0xDD00, 0xDF00, 0xD900, 0xDB00, 0xC500, 0xC700, 0xC100, 0xC300, 0xCD00, 0xCF00, 0xC900, 0xCB00, + 0xFEA8, 0xFEB8, 0xFE88, 0xFE98, 0xFEE8, 0xFEF8, 0xFEC8, 0xFED8, 0xFE28, 0xFE38, 0xFE08, 0xFE18, 0xFE68, 0xFE78, 0xFE48, 0xFE58, + 0xFFA8, 0xFFB8, 0xFF88, 0xFF98, 0xFFE8, 0xFFF8, 0xFFC8, 0xFFD8, 0xFF28, 0xFF38, 0xFF08, 0xFF18, 0xFF68, 0xFF78, 0xFF48, 0xFF58, + 0xFAA0, 0xFAE0, 0xFA20, 0xFA60, 0xFBA0, 0xFBE0, 0xFB20, 0xFB60, 0xF8A0, 0xF8E0, 0xF820, 0xF860, 0xF9A0, 0xF9E0, 0xF920, 0xF960, + 0xFD50, 0xFD70, 0xFD10, 0xFD30, 0xFDD0, 0xFDF0, 0xFD90, 0xFDB0, 0xFC50, 0xFC70, 0xFC10, 0xFC30, 0xFCD0, 0xFCF0, 0xFC90, 0xFCB0, + 0x1580, 0x1480, 0x1780, 0x1680, 0x1180, 0x1080, 0x1380, 0x1280, 0x1D80, 0x1C80, 0x1F80, 0x1E80, 0x1980, 0x1880, 0x1B80, 0x1A80, + 0x0AC0, 0x0A40, 0x0BC0, 0x0B40, 0x08C0, 0x0840, 0x09C0, 0x0940, 0x0EC0, 0x0E40, 0x0FC0, 0x0F40, 0x0CC0, 0x0C40, 0x0DC0, 0x0D40, + 0x5600, 0x5200, 0x5E00, 0x5A00, 0x4600, 0x4200, 0x4E00, 0x4A00, 0x7600, 0x7200, 0x7E00, 0x7A00, 0x6600, 0x6200, 0x6E00, 0x6A00, + 0x2B00, 0x2900, 0x2F00, 0x2D00, 0x2300, 0x2100, 0x2700, 0x2500, 0x3B00, 0x3900, 0x3F00, 0x3D00, 0x3300, 0x3100, 0x3700, 0x3500, + 0x0158, 0x0148, 0x0178, 0x0168, 0x0118, 0x0108, 0x0138, 0x0128, 0x01D8, 0x01C8, 0x01F8, 0x01E8, 0x0198, 0x0188, 0x01B8, 0x01A8, + 0x0058, 0x0048, 0x0078, 0x0068, 0x0018, 0x0008, 0x0038, 0x0028, 0x00D8, 0x00C8, 0x00F8, 0x00E8, 0x0098, 0x0088, 0x00B8, 0x00A8, + 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0, + 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350 +}; + +static unsigned short g_drwavMulawTable[256] = { + 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84, + 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84, + 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004, + 0xF0C4, 0xF144, 0xF1C4, 0xF244, 0xF2C4, 0xF344, 0xF3C4, 0xF444, 0xF4C4, 0xF544, 0xF5C4, 0xF644, 0xF6C4, 0xF744, 0xF7C4, 0xF844, + 0xF8A4, 0xF8E4, 0xF924, 0xF964, 0xF9A4, 0xF9E4, 0xFA24, 0xFA64, 0xFAA4, 0xFAE4, 0xFB24, 0xFB64, 0xFBA4, 0xFBE4, 0xFC24, 0xFC64, + 0xFC94, 0xFCB4, 0xFCD4, 0xFCF4, 0xFD14, 0xFD34, 0xFD54, 0xFD74, 0xFD94, 0xFDB4, 0xFDD4, 0xFDF4, 0xFE14, 0xFE34, 0xFE54, 0xFE74, + 0xFE8C, 0xFE9C, 0xFEAC, 0xFEBC, 0xFECC, 0xFEDC, 0xFEEC, 0xFEFC, 0xFF0C, 0xFF1C, 0xFF2C, 0xFF3C, 0xFF4C, 0xFF5C, 0xFF6C, 0xFF7C, + 0xFF88, 0xFF90, 0xFF98, 0xFFA0, 0xFFA8, 0xFFB0, 0xFFB8, 0xFFC0, 0xFFC8, 0xFFD0, 0xFFD8, 0xFFE0, 0xFFE8, 0xFFF0, 0xFFF8, 0x0000, + 0x7D7C, 0x797C, 0x757C, 0x717C, 0x6D7C, 0x697C, 0x657C, 0x617C, 0x5D7C, 0x597C, 0x557C, 0x517C, 0x4D7C, 0x497C, 0x457C, 0x417C, + 0x3E7C, 0x3C7C, 0x3A7C, 0x387C, 0x367C, 0x347C, 0x327C, 0x307C, 0x2E7C, 0x2C7C, 0x2A7C, 0x287C, 0x267C, 0x247C, 0x227C, 0x207C, + 0x1EFC, 0x1DFC, 0x1CFC, 0x1BFC, 0x1AFC, 
0x19FC, 0x18FC, 0x17FC, 0x16FC, 0x15FC, 0x14FC, 0x13FC, 0x12FC, 0x11FC, 0x10FC, 0x0FFC, + 0x0F3C, 0x0EBC, 0x0E3C, 0x0DBC, 0x0D3C, 0x0CBC, 0x0C3C, 0x0BBC, 0x0B3C, 0x0ABC, 0x0A3C, 0x09BC, 0x093C, 0x08BC, 0x083C, 0x07BC, + 0x075C, 0x071C, 0x06DC, 0x069C, 0x065C, 0x061C, 0x05DC, 0x059C, 0x055C, 0x051C, 0x04DC, 0x049C, 0x045C, 0x041C, 0x03DC, 0x039C, + 0x036C, 0x034C, 0x032C, 0x030C, 0x02EC, 0x02CC, 0x02AC, 0x028C, 0x026C, 0x024C, 0x022C, 0x020C, 0x01EC, 0x01CC, 0x01AC, 0x018C, + 0x0174, 0x0164, 0x0154, 0x0144, 0x0134, 0x0124, 0x0114, 0x0104, 0x00F4, 0x00E4, 0x00D4, 0x00C4, 0x00B4, 0x00A4, 0x0094, 0x0084, + 0x0078, 0x0070, 0x0068, 0x0060, 0x0058, 0x0050, 0x0048, 0x0040, 0x0038, 0x0030, 0x0028, 0x0020, 0x0018, 0x0010, 0x0008, 0x0000 +}; + +static DRWAV_INLINE drwav_int16 drwav__alaw_to_s16(drwav_uint8 sampleIn) +{ + return (short)g_drwavAlawTable[sampleIn]; +} + +static DRWAV_INLINE drwav_int16 drwav__mulaw_to_s16(drwav_uint8 sampleIn) +{ + return (short)g_drwavMulawTable[sampleIn]; +} + + + +static void drwav__pcm_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_s16(pOut, pIn, totalSampleCount); + return; + } + + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + for (i = 0; i < totalSampleCount; ++i) { + *pOut++ = ((const drwav_int16*)pIn)[i]; + } + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_s16(pOut, pIn, totalSampleCount); + return; + } + if (bytesPerSample == 4) { + drwav_s32_to_s16(pOut, (const drwav_int32*)pIn, totalSampleCount); + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. */ + for (i = 0; i < totalSampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (drwav_int16)((drwav_int64)sample >> 48); + } +} + +static void drwav__ieee_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + drwav_f32_to_s16(pOut, (const float*)pIn, totalSampleCount); + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_s16(pOut, (const double*)pIn, totalSampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */ + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } +} + +static drwav_uint64 drwav_read_pcm_frames_s16__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint32 bytesPerFrame; + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + /* Fast path. 
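When the stream is already 16 bit PCM, or the caller passed a NULL buffer and only wants to skip frames, the read is forwarded directly with no conversion. Every other format is decoded through the 4096 byte staging buffer below; for 24 bit stereo, for example, bytesPerFrame is 6, so each pass converts at most 4096 / 6 = 682 frames.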
*/ + if ((pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 16) || pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s16__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_s16(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + 
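/* A NULL output buffer reduces the read to a forward seek; drwav_read_pcm_frames() handles that case directly. */ + 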
if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + if (framesToRead * pWav->channels * sizeof(drwav_int16) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int16) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_s16__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_s16__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_s16__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_s16__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_s16__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_s16__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16le(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s16be(drwav* pWav, drwav_uint64 framesToRead, drwav_int16* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_s16(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = pIn[i]; + r = x << 8; + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_s24_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = ((int)(((unsigned int)(((const drwav_uint8*)pIn)[i*3+0]) << 8) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+1]) << 16) | ((unsigned int)(((const drwav_uint8*)pIn)[i*3+2])) << 24)) >> 8; + r = x >> 8; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_s32_to_s16(drwav_int16* pOut, const drwav_int32* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + int x = pIn[i]; + r = x >> 16; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_f32_to_s16(drwav_int16* pOut, const float* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + float x = pIn[i]; + float c; + c = ((x < -1) ? -1 : ((x > 1) ? 1 : x)); + c = c + 1; + r = (int)(c * 32767.5f); + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_f64_to_s16(drwav_int16* pOut, const double* pIn, size_t sampleCount) +{ + int r; + size_t i; + for (i = 0; i < sampleCount; ++i) { + double x = pIn[i]; + double c; + c = ((x < -1) ? -1 : ((x > 1) ? 
1 : x)); + c = c + 1; + r = (int)(c * 32767.5); + r = r - 32768; + pOut[i] = (short)r; + } +} + +DRWAV_API void drwav_alaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + for (i = 0; i < sampleCount; ++i) { + pOut[i] = drwav__alaw_to_s16(pIn[i]); + } +} + +DRWAV_API void drwav_mulaw_to_s16(drwav_int16* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + for (i = 0; i < sampleCount; ++i) { + pOut[i] = drwav__mulaw_to_s16(pIn[i]); + } +} + + + +static void drwav__pcm_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_f32(pOut, pIn, sampleCount); + return; + } + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + drwav_s16_to_f32(pOut, (const drwav_int16*)pIn, sampleCount); + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_f32(pOut, pIn, sampleCount); + return; + } + if (bytesPerSample == 4) { + drwav_s32_to_f32(pOut, (const drwav_int32*)pIn, sampleCount); + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. */ + for (i = 0; i < sampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (float)((drwav_int64)sample / 9223372036854775807.0); + } +} + +static void drwav__ieee_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + unsigned int i; + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((const float*)pIn)[i]; + } + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_f32(pOut, (const double*)pIn, sampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */ + DRWAV_ZERO_MEMORY(pOut, sampleCount * sizeof(*pOut)); + return; + } +} + + +static drwav_uint64 drwav_read_pcm_frames_f32__pcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_f32(pBufferOut, sampleData, (size_t)framesRead*pWav->channels, bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. 
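Frames are decoded into a temporary s16 buffer, 2048 samples at a time, and each batch is then converted to f32 with drwav_s16_to_f32().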
+ */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__ima(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_f32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__ieee(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + /* Fast path. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT && pWav->bitsPerSample == 32) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__alaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_f32__mulaw(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, 
drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_f32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. */ + if (framesToRead * pWav->channels * sizeof(float) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(float) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_f32__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_f32__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_f32__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_f32__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_f32__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_f32__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32le(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_f32be(drwav* pWav, drwav_uint64 framesToRead, float* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_f32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_f32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + +#ifdef DR_WAV_LIBSNDFILE_COMPAT + /* + It appears libsndfile uses slightly different logic for the u8 -> f32 conversion to dr_wav, which in my opinion is incorrect. It appears + libsndfile performs the conversion something like "f32 = (u8 / 256) * 2 - 1", however I think it should be "f32 = (u8 / 255) * 2 - 1" (note + the divisor of 256 vs 255). I use libsndfile as a benchmark for testing, so I'm therefore leaving this block here just for my automated + correctness testing. This is disabled by default. 
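As a worked example: an input byte of 255 maps to 0.9921875 with the 256 divisor, but to exactly 1.0 with 255.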
+ */ + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (pIn[i] / 256.0f) * 2 - 1; + } +#else + for (i = 0; i < sampleCount; ++i) { + float x = pIn[i]; + x = x * 0.00784313725490196078f; /* 0..255 to 0..2 */ + x = x - 1; /* 0..2 to -1..1 */ + + *pOut++ = x; + } +#endif +} + +DRWAV_API void drwav_s16_to_f32(float* pOut, const drwav_int16* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = pIn[i] * 0.000030517578125f; + } +} + +DRWAV_API void drwav_s24_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + double x; + drwav_uint32 a = ((drwav_uint32)(pIn[i*3+0]) << 8); + drwav_uint32 b = ((drwav_uint32)(pIn[i*3+1]) << 16); + drwav_uint32 c = ((drwav_uint32)(pIn[i*3+2]) << 24); + + x = (double)((drwav_int32)(a | b | c) >> 8); + *pOut++ = (float)(x * 0.00000011920928955078125); + } +} + +DRWAV_API void drwav_s32_to_f32(float* pOut, const drwav_int32* pIn, size_t sampleCount) +{ + size_t i; + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (float)(pIn[i] / 2147483648.0); + } +} + +DRWAV_API void drwav_f64_to_f32(float* pOut, const double* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (float)pIn[i]; + } +} + +DRWAV_API void drwav_alaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = drwav__alaw_to_s16(pIn[i]) / 32768.0f; + } +} + +DRWAV_API void drwav_mulaw_to_f32(float* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = drwav__mulaw_to_s16(pIn[i]) / 32768.0f; + } +} + + + +static void drwav__pcm_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + unsigned int i; + + /* Special case for 8-bit sample data because it's treated as unsigned. */ + if (bytesPerSample == 1) { + drwav_u8_to_s32(pOut, pIn, totalSampleCount); + return; + } + + /* Slightly more optimal implementation for common formats. */ + if (bytesPerSample == 2) { + drwav_s16_to_s32(pOut, (const drwav_int16*)pIn, totalSampleCount); + return; + } + if (bytesPerSample == 3) { + drwav_s24_to_s32(pOut, pIn, totalSampleCount); + return; + } + if (bytesPerSample == 4) { + for (i = 0; i < totalSampleCount; ++i) { + *pOut++ = ((const drwav_int32*)pIn)[i]; + } + return; + } + + + /* Anything more than 64 bits per sample is not supported. */ + if (bytesPerSample > 8) { + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } + + + /* Generic, slow converter. 
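Bytes are assembled little-endian into the high end of a 64-bit accumulator and then arithmetic-shifted down to 32 bits so the sign is preserved.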
*/ + for (i = 0; i < totalSampleCount; ++i) { + drwav_uint64 sample = 0; + unsigned int shift = (8 - bytesPerSample) * 8; + + unsigned int j; + for (j = 0; j < bytesPerSample; j += 1) { + DRWAV_ASSERT(j < 8); + sample |= (drwav_uint64)(pIn[j]) << shift; + shift += 8; + } + + pIn += j; + *pOut++ = (drwav_int32)((drwav_int64)sample >> 32); + } +} + +static void drwav__ieee_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) +{ + if (bytesPerSample == 4) { + drwav_f32_to_s32(pOut, (const float*)pIn, totalSampleCount); + return; + } else if (bytesPerSample == 8) { + drwav_f64_to_s32(pOut, (const double*)pIn, totalSampleCount); + return; + } else { + /* Only supporting 32- and 64-bit float. Output silence in all other cases. Contributions welcome for 16-bit float. */ + DRWAV_ZERO_MEMORY(pOut, totalSampleCount * sizeof(*pOut)); + return; + } +} + + +static drwav_uint64 drwav_read_pcm_frames_s32__pcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + drwav_uint32 bytesPerFrame; + + /* Fast path. */ + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM && pWav->bitsPerSample == 32) { + return drwav_read_pcm_frames(pWav, framesToRead, pBufferOut); + } + + bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__pcm_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__msadpcm(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. */ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__ima(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + /* + We're just going to borrow the implementation from the drwav_read_s16() since IMA-ADPCM is a little bit more complicated than other formats and I don't + want to duplicate that code. + */ + drwav_uint64 totalFramesRead = 0; + drwav_int16 samples16[2048]; + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames_s16(pWav, drwav_min(framesToRead, drwav_countof(samples16)/pWav->channels), samples16); + if (framesRead == 0) { + break; + } + + drwav_s16_to_s32(pBufferOut, samples16, (size_t)(framesRead*pWav->channels)); /* <-- Safe cast because we're clamping to 2048. 
*/ + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__ieee(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav__ieee_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels), bytesPerFrame/pWav->channels); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__alaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_alaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +static drwav_uint64 drwav_read_pcm_frames_s32__mulaw(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 totalFramesRead; + drwav_uint8 sampleData[4096]; + + drwav_uint32 bytesPerFrame = drwav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame == 0) { + return 0; + } + + totalFramesRead = 0; + + while (framesToRead > 0) { + drwav_uint64 framesRead = drwav_read_pcm_frames(pWav, drwav_min(framesToRead, sizeof(sampleData)/bytesPerFrame), sampleData); + if (framesRead == 0) { + break; + } + + drwav_mulaw_to_s32(pBufferOut, sampleData, (size_t)(framesRead*pWav->channels)); + + pBufferOut += framesRead*pWav->channels; + framesToRead -= framesRead; + totalFramesRead += framesRead; + } + + return totalFramesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + if (pWav == NULL || framesToRead == 0) { + return 0; + } + + if (pBufferOut == NULL) { + return drwav_read_pcm_frames(pWav, framesToRead, NULL); + } + + /* Don't try to read more samples than can potentially fit in the output buffer. 
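If the requested frame count would make the output byte count overflow a size_t, it is clamped below. 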
*/ + if (framesToRead * pWav->channels * sizeof(drwav_int32) > DRWAV_SIZE_MAX) { + framesToRead = DRWAV_SIZE_MAX / sizeof(drwav_int32) / pWav->channels; + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_PCM) { + return drwav_read_pcm_frames_s32__pcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ADPCM) { + return drwav_read_pcm_frames_s32__msadpcm(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_IEEE_FLOAT) { + return drwav_read_pcm_frames_s32__ieee(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_ALAW) { + return drwav_read_pcm_frames_s32__alaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_MULAW) { + return drwav_read_pcm_frames_s32__mulaw(pWav, framesToRead, pBufferOut); + } + + if (pWav->translatedFormatTag == DR_WAVE_FORMAT_DVI_ADPCM) { + return drwav_read_pcm_frames_s32__ima(pWav, framesToRead, pBufferOut); + } + + return 0; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32le(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_FALSE) { + drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + +DRWAV_API drwav_uint64 drwav_read_pcm_frames_s32be(drwav* pWav, drwav_uint64 framesToRead, drwav_int32* pBufferOut) +{ + drwav_uint64 framesRead = drwav_read_pcm_frames_s32(pWav, framesToRead, pBufferOut); + if (pBufferOut != NULL && drwav__is_little_endian() == DRWAV_TRUE) { + drwav__bswap_samples_s32(pBufferOut, framesRead*pWav->channels); + } + + return framesRead; +} + + +DRWAV_API void drwav_u8_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((int)pIn[i] - 128) << 24; + } +} + +DRWAV_API void drwav_s16_to_s32(drwav_int32* pOut, const drwav_int16* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = pIn[i] << 16; + } +} + +DRWAV_API void drwav_s24_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + unsigned int s0 = pIn[i*3 + 0]; + unsigned int s1 = pIn[i*3 + 1]; + unsigned int s2 = pIn[i*3 + 2]; + + drwav_int32 sample32 = (drwav_int32)((s0 << 8) | (s1 << 16) | (s2 << 24)); + *pOut++ = sample32; + } +} + +DRWAV_API void drwav_f32_to_s32(drwav_int32* pOut, const float* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]); + } +} + +DRWAV_API void drwav_f64_to_s32(drwav_int32* pOut, const double* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = (drwav_int32)(2147483648.0 * pIn[i]); + } +} + +DRWAV_API void drwav_alaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i = 0; i < sampleCount; ++i) { + *pOut++ = ((drwav_int32)drwav__alaw_to_s16(pIn[i])) << 16; + } +} + +DRWAV_API void drwav_mulaw_to_s32(drwav_int32* pOut, const drwav_uint8* pIn, 
size_t sampleCount) +{ + size_t i; + + if (pOut == NULL || pIn == NULL) { + return; + } + + for (i= 0; i < sampleCount; ++i) { + *pOut++ = ((drwav_int32)drwav__mulaw_to_s16(pIn[i])) << 16; + } +} + + + +static drwav_int16* drwav__read_pcm_frames_and_close_s16(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + drwav_int16* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int16); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (drwav_int16*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. */ + } + + framesRead = drwav_read_pcm_frames_s16(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + +static float* drwav__read_pcm_frames_and_close_f32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + float* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (float*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. */ + } + + framesRead = drwav_read_pcm_frames_f32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + +static drwav_int32* drwav__read_pcm_frames_and_close_s32(drwav* pWav, unsigned int* channels, unsigned int* sampleRate, drwav_uint64* totalFrameCount) +{ + drwav_uint64 sampleDataSize; + drwav_int32* pSampleData; + drwav_uint64 framesRead; + + DRWAV_ASSERT(pWav != NULL); + + sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(drwav_int32); + if (sampleDataSize > DRWAV_SIZE_MAX) { + drwav_uninit(pWav); + return NULL; /* File's too big. */ + } + + pSampleData = (drwav_int32*)drwav__malloc_from_callbacks((size_t)sampleDataSize, &pWav->allocationCallbacks); /* <-- Safe cast due to the check above. */ + if (pSampleData == NULL) { + drwav_uninit(pWav); + return NULL; /* Failed to allocate memory. 
*/ + } + + framesRead = drwav_read_pcm_frames_s32(pWav, (size_t)pWav->totalPCMFrameCount, pSampleData); + if (framesRead != pWav->totalPCMFrameCount) { + drwav__free_from_callbacks(pSampleData, &pWav->allocationCallbacks); + drwav_uninit(pWav); + return NULL; /* There was an error reading the samples. */ + } + + drwav_uninit(pWav); + + if (sampleRate) { + *sampleRate = pWav->sampleRate; + } + if (channels) { + *channels = pWav->channels; + } + if (totalFrameCount) { + *totalFrameCount = pWav->totalPCMFrameCount; + } + + return pSampleData; +} + + + +DRWAV_API drwav_int16* drwav_open_and_read_pcm_frames_s16(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_and_read_pcm_frames_f32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_and_read_pcm_frames_s32(drwav_read_proc onRead, drwav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +#ifndef DR_WAV_NO_STDIO +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, 
pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + + +DRWAV_API drwav_int16* drwav_open_file_and_read_pcm_frames_s16_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_file_and_read_pcm_frames_f32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_file_and_read_pcm_frames_s32_w(const wchar_t* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (channelsOut) { + *channelsOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_file_w(&wav, filename, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} +#endif + +DRWAV_API drwav_int16* drwav_open_memory_and_read_pcm_frames_s16(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API float* drwav_open_memory_and_read_pcm_frames_f32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if 
(totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} + +DRWAV_API drwav_int32* drwav_open_memory_and_read_pcm_frames_s32(const void* data, size_t dataSize, unsigned int* channelsOut, unsigned int* sampleRateOut, drwav_uint64* totalFrameCountOut, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + drwav wav; + + if (channelsOut) { + *channelsOut = 0; + } + if (sampleRateOut) { + *sampleRateOut = 0; + } + if (totalFrameCountOut) { + *totalFrameCountOut = 0; + } + + if (!drwav_init_memory(&wav, data, dataSize, pAllocationCallbacks)) { + return NULL; + } + + return drwav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); +} +#endif /* DR_WAV_NO_CONVERSION_API */ + + +DRWAV_API void drwav_free(void* p, const drwav_allocation_callbacks* pAllocationCallbacks) +{ + if (pAllocationCallbacks != NULL) { + drwav__free_from_callbacks(p, pAllocationCallbacks); + } else { + drwav__free_default(p, NULL); + } +} + +DRWAV_API drwav_uint16 drwav_bytes_to_u16(const drwav_uint8* data) +{ + return drwav__bytes_to_u16(data); +} + +DRWAV_API drwav_int16 drwav_bytes_to_s16(const drwav_uint8* data) +{ + return drwav__bytes_to_s16(data); +} + +DRWAV_API drwav_uint32 drwav_bytes_to_u32(const drwav_uint8* data) +{ + return drwav__bytes_to_u32(data); +} + +DRWAV_API drwav_int32 drwav_bytes_to_s32(const drwav_uint8* data) +{ + return drwav__bytes_to_s32(data); +} + +DRWAV_API drwav_uint64 drwav_bytes_to_u64(const drwav_uint8* data) +{ + return drwav__bytes_to_u64(data); +} + +DRWAV_API drwav_int64 drwav_bytes_to_s64(const drwav_uint8* data) +{ + return drwav__bytes_to_s64(data); +} + + +DRWAV_API drwav_bool32 drwav_guid_equal(const drwav_uint8 a[16], const drwav_uint8 b[16]) +{ + return drwav__guid_equal(a, b); +} + +DRWAV_API drwav_bool32 drwav_fourcc_equal(const drwav_uint8* a, const char* b) +{ + return drwav__fourcc_equal(a, b); +} + +#endif /* dr_wav_c */ +#endif /* DR_WAV_IMPLEMENTATION */ + +/* +RELEASE NOTES - v0.11.0 +======================= +Version 0.11.0 has breaking API changes. + +Improved Client-Defined Memory Allocation +----------------------------------------- +The main change with this release is the addition of a more flexible way of implementing custom memory allocation routines. The +existing system of DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE are still in place and will be used by default when no custom +allocation callbacks are specified. + +To use the new system, you pass in a pointer to a drwav_allocation_callbacks object to drwav_init() and family, like this: + + void* my_malloc(size_t sz, void* pUserData) + { + return malloc(sz); + } + void* my_realloc(void* p, size_t sz, void* pUserData) + { + return realloc(p, sz); + } + void my_free(void* p, void* pUserData) + { + free(p); + } + + ... + + drwav_allocation_callbacks allocationCallbacks; + allocationCallbacks.pUserData = &myData; + allocationCallbacks.onMalloc = my_malloc; + allocationCallbacks.onRealloc = my_realloc; + allocationCallbacks.onFree = my_free; + drwav_init_file(&wav, "my_file.wav", &allocationCallbacks); + +The advantage of this new system is that it allows you to specify user data which will be passed in to the allocation routines. 
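 + +As a usage sketch (using the open-and-read helpers declared earlier in this file), memory returned by those helpers is allocated through the same callbacks, so it should be released with drwav_free() using the same object: + + unsigned int channels; + unsigned int sampleRate; + drwav_uint64 totalFrameCount; + drwav_int16* pFrames = drwav_open_file_and_read_pcm_frames_s16("my_file.wav", &channels, &sampleRate, &totalFrameCount, &allocationCallbacks); + if (pFrames != NULL) { + ... + drwav_free(pFrames, &allocationCallbacks); + }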
 + +Passing in null for the allocation callbacks object will cause dr_wav to use the defaults (DRWAV_MALLOC, +DRWAV_REALLOC and DRWAV_FREE), which is the equivalent of how it worked in previous versions. + +Every API that opens a drwav object now takes this extra parameter. These include the following: + + drwav_init() + drwav_init_ex() + drwav_init_file() + drwav_init_file_ex() + drwav_init_file_w() + drwav_init_file_w_ex() + drwav_init_memory() + drwav_init_memory_ex() + drwav_init_write() + drwav_init_write_sequential() + drwav_init_write_sequential_pcm_frames() + drwav_init_file_write() + drwav_init_file_write_sequential() + drwav_init_file_write_sequential_pcm_frames() + drwav_init_file_write_w() + drwav_init_file_write_sequential_w() + drwav_init_file_write_sequential_pcm_frames_w() + drwav_init_memory_write() + drwav_init_memory_write_sequential() + drwav_init_memory_write_sequential_pcm_frames() + drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_pcm_frames_f32() + drwav_open_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_s16_w() + drwav_open_file_and_read_pcm_frames_f32_w() + drwav_open_file_and_read_pcm_frames_s32_w() + drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_pcm_frames_f32() + drwav_open_memory_and_read_pcm_frames_s32() + +Endian Improvements +------------------- +Previously, the following APIs returned little-endian audio data. These now return native-endian data. This improves compatibility +on big-endian architectures. + + drwav_read_pcm_frames() + drwav_read_pcm_frames_s16() + drwav_read_pcm_frames_s32() + drwav_read_pcm_frames_f32() + drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_pcm_frames_s32() + drwav_open_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_pcm_frames_s32() + drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_pcm_frames_s16_w() + drwav_open_file_and_read_pcm_frames_s32_w() + drwav_open_file_and_read_pcm_frames_f32_w() + drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_pcm_frames_s32() + drwav_open_memory_and_read_pcm_frames_f32() + +APIs have been added to give you explicit control over whether audio data is read or written in big- or little-endian byte +order: + + drwav_read_pcm_frames_le() + drwav_read_pcm_frames_be() + drwav_read_pcm_frames_s16le() + drwav_read_pcm_frames_s16be() + drwav_read_pcm_frames_f32le() + drwav_read_pcm_frames_f32be() + drwav_read_pcm_frames_s32le() + drwav_read_pcm_frames_s32be() + drwav_write_pcm_frames_le() + drwav_write_pcm_frames_be() + +Removed APIs +------------ +The following APIs were deprecated in version 0.10.0 and have now been removed: + + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + + + +RELEASE NOTES - v0.10.0 +======================= +Version 0.10.0 has breaking API changes. There are no significant bug fixes in this release, so if you are affected you do 
+ +Removed APIs +------------ +The following APIs were deprecated in version 0.9.0 and have been completely removed in version 0.10.0: + + drwav_read() + drwav_read_s16() + drwav_read_f32() + drwav_read_s32() + drwav_seek_to_sample() + drwav_write() + drwav_open_and_read_s16() + drwav_open_and_read_f32() + drwav_open_and_read_s32() + drwav_open_file_and_read_s16() + drwav_open_file_and_read_f32() + drwav_open_file_and_read_s32() + drwav_open_memory_and_read_s16() + drwav_open_memory_and_read_f32() + drwav_open_memory_and_read_s32() + drwav::totalSampleCount + +See release notes for version 0.9.0 at the bottom of this file for replacement APIs. + +Deprecated APIs +--------------- +The following APIs have been deprecated. There is a confusing and completely arbitrary difference between drwav_init*() and +drwav_open*(), where drwav_init*() initializes a pre-allocated drwav object, whereas drwav_open*() will first allocated a +drwav object on the heap and then initialize it. drwav_open*() has been deprecated which means you must now use a pre- +allocated drwav object with drwav_init*(). If you need the previous functionality, you can just do a malloc() followed by +a called to one of the drwav_init*() APIs. + + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + +These APIs will be removed completely in a future version. The rationale for this change is to remove confusion between the +two different ways to initialize a drwav object. +*/ + +/* +REVISION HISTORY +================ +v0.12.16 - 2020-12-02 + - Fix a bug when trying to read more bytes than can fit in a size_t. + +v0.12.15 - 2020-11-21 + - Fix compilation with OpenWatcom. + +v0.12.14 - 2020-11-13 + - Minor code clean up. + +v0.12.13 - 2020-11-01 + - Improve compiler support for older versions of GCC. + +v0.12.12 - 2020-09-28 + - Add support for RF64. + - Fix a bug in writing mode where the size of the RIFF chunk incorrectly includes the header section. + +v0.12.11 - 2020-09-08 + - Fix a compilation error on older compilers. + +v0.12.10 - 2020-08-24 + - Fix a bug when seeking with ADPCM formats. + +v0.12.9 - 2020-08-02 + - Simplify sized types. + +v0.12.8 - 2020-07-25 + - Fix a compilation warning. + +v0.12.7 - 2020-07-15 + - Fix some bugs on big-endian architectures. + - Fix an error in s24 to f32 conversion. + +v0.12.6 - 2020-06-23 + - Change drwav_read_*() to allow NULL to be passed in as the output buffer which is equivalent to a forward seek. + - Fix a buffer overflow when trying to decode invalid IMA-ADPCM files. + - Add include guard for the implementation section. + +v0.12.5 - 2020-05-27 + - Minor documentation fix. + +v0.12.4 - 2020-05-16 + - Replace assert() with DRWAV_ASSERT(). + - Add compile-time and run-time version querying. + - DRWAV_VERSION_MINOR + - DRWAV_VERSION_MAJOR + - DRWAV_VERSION_REVISION + - DRWAV_VERSION_STRING + - drwav_version() + - drwav_version_string() + +v0.12.3 - 2020-04-30 + - Fix compilation errors with VC6. + +v0.12.2 - 2020-04-21 + - Fix a bug where drwav_init_file() does not close the file handle after attempting to load an erroneous file. + +v0.12.1 - 2020-04-13 + - Fix some pedantic warnings. + +v0.12.0 - 2020-04-04 + - API CHANGE: Add container and format parameters to the chunk callback. + - Minor documentation updates. 
 + +v0.11.5 - 2020-03-07 + - Fix compilation error with Visual Studio .NET 2003. + +v0.11.4 - 2020-01-29 + - Fix some static analysis warnings. + - Fix a bug when reading f32 samples from an A-law encoded stream. + +v0.11.3 - 2020-01-12 + - Minor changes to some f32 format conversion routines. + - Minor bug fix for ADPCM conversion when end of file is reached. + +v0.11.2 - 2019-12-02 + - Fix a possible crash when using custom memory allocators without a custom realloc() implementation. + - Fix an integer overflow bug. + - Fix a null pointer dereference bug. + - Add limits to sample rate, channels and bits per sample to tighten up some validation. + +v0.11.1 - 2019-10-07 + - Internal code clean up. + +v0.11.0 - 2019-10-06 + - API CHANGE: Add support for user-defined memory allocation routines. This system allows the program to specify its own memory allocation + routines with a user data pointer for client-specific contextual data. This adds an extra parameter to the end of the following APIs: + - drwav_init() + - drwav_init_ex() + - drwav_init_file() + - drwav_init_file_ex() + - drwav_init_file_w() + - drwav_init_file_w_ex() + - drwav_init_memory() + - drwav_init_memory_ex() + - drwav_init_write() + - drwav_init_write_sequential() + - drwav_init_write_sequential_pcm_frames() + - drwav_init_file_write() + - drwav_init_file_write_sequential() + - drwav_init_file_write_sequential_pcm_frames() + - drwav_init_file_write_w() + - drwav_init_file_write_sequential_w() + - drwav_init_file_write_sequential_pcm_frames_w() + - drwav_init_memory_write() + - drwav_init_memory_write_sequential() + - drwav_init_memory_write_sequential_pcm_frames() + - drwav_open_and_read_pcm_frames_s16() + - drwav_open_and_read_pcm_frames_f32() + - drwav_open_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_s16() + - drwav_open_file_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_s16_w() + - drwav_open_file_and_read_pcm_frames_f32_w() + - drwav_open_file_and_read_pcm_frames_s32_w() + - drwav_open_memory_and_read_pcm_frames_s16() + - drwav_open_memory_and_read_pcm_frames_f32() + - drwav_open_memory_and_read_pcm_frames_s32() + Set this extra parameter to NULL to use the defaults, which is the same as the previous behaviour. Setting this to NULL will use + DRWAV_MALLOC, DRWAV_REALLOC and DRWAV_FREE. + - Add support for reading and writing PCM frames in an explicit endianness. New APIs: + - drwav_read_pcm_frames_le() + - drwav_read_pcm_frames_be() + - drwav_read_pcm_frames_s16le() + - drwav_read_pcm_frames_s16be() + - drwav_read_pcm_frames_f32le() + - drwav_read_pcm_frames_f32be() + - drwav_read_pcm_frames_s32le() + - drwav_read_pcm_frames_s32be() + - drwav_write_pcm_frames_le() + - drwav_write_pcm_frames_be() + - Remove deprecated APIs. + - API CHANGE: The following APIs now return native-endian data. Previously they returned little-endian data. 
+ - drwav_read_pcm_frames() + - drwav_read_pcm_frames_s16() + - drwav_read_pcm_frames_s32() + - drwav_read_pcm_frames_f32() + - drwav_open_and_read_pcm_frames_s16() + - drwav_open_and_read_pcm_frames_s32() + - drwav_open_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s16() + - drwav_open_file_and_read_pcm_frames_s32() + - drwav_open_file_and_read_pcm_frames_f32() + - drwav_open_file_and_read_pcm_frames_s16_w() + - drwav_open_file_and_read_pcm_frames_s32_w() + - drwav_open_file_and_read_pcm_frames_f32_w() + - drwav_open_memory_and_read_pcm_frames_s16() + - drwav_open_memory_and_read_pcm_frames_s32() + - drwav_open_memory_and_read_pcm_frames_f32() + +v0.10.1 - 2019-08-31 + - Correctly handle partial trailing ADPCM blocks. + +v0.10.0 - 2019-08-04 + - Remove deprecated APIs. + - Add wchar_t variants for file loading APIs: + drwav_init_file_w() + drwav_init_file_ex_w() + drwav_init_file_write_w() + drwav_init_file_write_sequential_w() + - Add drwav_target_write_size_bytes() which calculates the total size in bytes of a WAV file given a format and sample count. + - Add APIs for specifying the PCM frame count instead of the sample count when opening in sequential write mode: + drwav_init_write_sequential_pcm_frames() + drwav_init_file_write_sequential_pcm_frames() + drwav_init_file_write_sequential_pcm_frames_w() + drwav_init_memory_write_sequential_pcm_frames() + - Deprecate drwav_open*() and drwav_close(): + drwav_open() + drwav_open_ex() + drwav_open_write() + drwav_open_write_sequential() + drwav_open_file() + drwav_open_file_ex() + drwav_open_file_write() + drwav_open_file_write_sequential() + drwav_open_memory() + drwav_open_memory_ex() + drwav_open_memory_write() + drwav_open_memory_write_sequential() + drwav_close() + - Minor documentation updates. + +v0.9.2 - 2019-05-21 + - Fix warnings. + +v0.9.1 - 2019-05-05 + - Add support for C89. + - Change license to choice of public domain or MIT-0. + +v0.9.0 - 2018-12-16 + - API CHANGE: Add new reading APIs for reading by PCM frames instead of samples. Old APIs have been deprecated and + will be removed in v0.10.0. Deprecated APIs and their replacements: + drwav_read() -> drwav_read_pcm_frames() + drwav_read_s16() -> drwav_read_pcm_frames_s16() + drwav_read_f32() -> drwav_read_pcm_frames_f32() + drwav_read_s32() -> drwav_read_pcm_frames_s32() + drwav_seek_to_sample() -> drwav_seek_to_pcm_frame() + drwav_write() -> drwav_write_pcm_frames() + drwav_open_and_read_s16() -> drwav_open_and_read_pcm_frames_s16() + drwav_open_and_read_f32() -> drwav_open_and_read_pcm_frames_f32() + drwav_open_and_read_s32() -> drwav_open_and_read_pcm_frames_s32() + drwav_open_file_and_read_s16() -> drwav_open_file_and_read_pcm_frames_s16() + drwav_open_file_and_read_f32() -> drwav_open_file_and_read_pcm_frames_f32() + drwav_open_file_and_read_s32() -> drwav_open_file_and_read_pcm_frames_s32() + drwav_open_memory_and_read_s16() -> drwav_open_memory_and_read_pcm_frames_s16() + drwav_open_memory_and_read_f32() -> drwav_open_memory_and_read_pcm_frames_f32() + drwav_open_memory_and_read_s32() -> drwav_open_memory_and_read_pcm_frames_s32() + drwav::totalSampleCount -> drwav::totalPCMFrameCount + - API CHANGE: Rename drwav_open_and_read_file_*() to drwav_open_file_and_read_*(). + - API CHANGE: Rename drwav_open_and_read_memory_*() to drwav_open_memory_and_read_*(). + - Add built-in support for smpl chunks. + - Add support for firing a callback for each chunk in the file at initialization time. + - This is enabled through the drwav_init_ex(), etc. 
family of APIs. + - Handle invalid FMT chunks more robustly. + +v0.8.5 - 2018-09-11 + - Const correctness. + - Fix a potential stack overflow. + +v0.8.4 - 2018-08-07 + - Improve 64-bit detection. + +v0.8.3 - 2018-08-05 + - Fix C++ build on older versions of GCC. + +v0.8.2 - 2018-08-02 + - Fix some big-endian bugs. + +v0.8.1 - 2018-06-29 + - Add support for sequential writing APIs. + - Disable seeking in write mode. + - Fix bugs with Wave64. + - Fix typos. + +v0.8 - 2018-04-27 + - Bug fix. + - Start using major.minor.revision versioning. + +v0.7f - 2018-02-05 + - Restrict ADPCM formats to a maximum of 2 channels. + +v0.7e - 2018-02-02 + - Fix a crash. + +v0.7d - 2018-02-01 + - Fix a crash. + +v0.7c - 2018-02-01 + - Set drwav.bytesPerSample to 0 for all compressed formats. + - Fix a crash when reading 16-bit floating point WAV files. In this case dr_wav will output silence for + all format conversion reading APIs (*_s16, *_s32, *_f32 APIs). + - Fix some divide-by-zero errors. + +v0.7b - 2018-01-22 + - Fix errors with seeking of compressed formats. + - Fix compilation error when DR_WAV_NO_CONVERSION_API + +v0.7a - 2017-11-17 + - Fix some GCC warnings. + +v0.7 - 2017-11-04 + - Add writing APIs. + +v0.6 - 2017-08-16 + - API CHANGE: Rename dr_* types to drwav_*. + - Add support for custom implementations of malloc(), realloc(), etc. + - Add support for Microsoft ADPCM. + - Add support for IMA ADPCM (DVI, format code 0x11). + - Optimizations to drwav_read_s16(). + - Bug fixes. + +v0.5g - 2017-07-16 + - Change underlying type for booleans to unsigned. + +v0.5f - 2017-04-04 + - Fix a minor bug with drwav_open_and_read_s16() and family. + +v0.5e - 2016-12-29 + - Added support for reading samples as signed 16-bit integers. Use the _s16() family of APIs for this. + - Minor fixes to documentation. + +v0.5d - 2016-12-28 + - Use drwav_int* and drwav_uint* sized types to improve compiler support. + +v0.5c - 2016-11-11 + - Properly handle JUNK chunks that come before the FMT chunk. + +v0.5b - 2016-10-23 + - A minor change to drwav_bool8 and drwav_bool32 types. + +v0.5a - 2016-10-11 + - Fixed a bug with drwav_open_and_read() and family due to incorrect argument ordering. + - Improve A-law and mu-law efficiency. + +v0.5 - 2016-09-29 + - API CHANGE. Swap the order of "channels" and "sampleRate" parameters in drwav_open_and_read*(). Rationale for this is to + keep it consistent with dr_audio and dr_flac. + +v0.4b - 2016-09-18 + - Fixed a typo in documentation. + +v0.4a - 2016-09-18 + - Fixed a typo. + - Change date format to ISO 8601 (YYYY-MM-DD) + +v0.4 - 2016-07-13 + - API CHANGE. Make onSeek consistent with dr_flac. + - API CHANGE. Rename drwav_seek() to drwav_seek_to_sample() for clarity and consistency with dr_flac. + - Added support for Sony Wave64. + +v0.3a - 2016-05-28 + - API CHANGE. Return drwav_bool32 instead of int in onSeek callback. + - Fixed a memory leak. + +v0.3 - 2016-05-22 + - Lots of API changes for consistency. + +v0.2a - 2016-05-16 + - Fixed Linux/GCC build. + +v0.2 - 2016-05-11 + - Added support for reading data as signed 32-bit PCM for consistency with dr_flac. + +v0.1a - 2016-05-07 + - Fixed a bug in drwav_open_file() where the file handle would not be closed if the loader failed to initialize. + +v0.1 - 2016-05-04 + - Initial versioned release. +*/ + +/* +This software is available as a choice of the following licenses. Choose +whichever you prefer. 
+
+===============================================================================
+ALTERNATIVE 1 - Public Domain (www.unlicense.org)
+===============================================================================
+This is free and unencumbered software released into the public domain.
+
+Anyone is free to copy, modify, publish, use, compile, sell, or distribute this
+software, either in source code form or as a compiled binary, for any purpose,
+commercial or non-commercial, and by any means.
+
+In jurisdictions that recognize copyright laws, the author or authors of this
+software dedicate any and all copyright interest in the software to the public
+domain. We make this dedication for the benefit of the public at large and to
+the detriment of our heirs and successors. We intend this dedication to be an
+overt act of relinquishment in perpetuity of all present and future rights to
+this software under copyright law.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+For more information, please refer to http://unlicense.org/
+
+===============================================================================
+ALTERNATIVE 2 - MIT No Attribution
+===============================================================================
+Copyright 2020 David Reid
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
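+
+As a worked example of the v0.11.0 allocation-callback parameter documented in the changelog above, a client can
+route every allocation dr_wav makes through its own functions (a sketch; the my_* wrappers are hypothetical):
+
+    #include <stdlib.h>   // malloc/realloc/free
+
+    void* my_malloc (size_t sz, void* pUserData)          { (void)pUserData; return malloc(sz); }
+    void* my_realloc(void* p, size_t sz, void* pUserData) { (void)pUserData; return realloc(p, sz); }
+    void  my_free   (void* p, void* pUserData)            { (void)pUserData; free(p); }
+
+    drwav_allocation_callbacks cb;
+    cb.pUserData = NULL;        // handed back to every callback
+    cb.onMalloc  = my_malloc;
+    cb.onRealloc = my_realloc;
+    cb.onFree    = my_free;
+
+    drwav wav;
+    drwav_init_file(&wav, "audio.wav", &cb);  // passing NULL instead of &cb selects DRWAV_MALLOC/REALLOC/FREE
+    drwav_uninit(&wav);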
+*/
\ No newline at end of file
diff --git a/otherarch/whispercpp/main.cpp b/otherarch/whispercpp/main.cpp
new file mode 100644
index 000000000..a201cce6a
--- /dev/null
+++ b/otherarch/whispercpp/main.cpp
@@ -0,0 +1,1374 @@
+#include "common.h"
+
+#include "whisper.h"
+#include "grammar-parser.h"
+#define DR_WAV_IMPLEMENTATION
+#include "dr_wav.h"
+
+#include <cmath>
+#include <fstream>
+#include <cstdio>
+#include <regex>
+#include <string>
+#include <thread>
+#include <vector>
+#include <cstring>
+
+#define COMMON_SAMPLE_RATE 16000
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+// helper function to replace substrings
+static void l_replace_all(std::string & s, const std::string & search, const std::string & replace) {
+    for (size_t pos = 0; ; pos += replace.length()) {
+        pos = s.find(search, pos);
+        if (pos == std::string::npos) break;
+        s.erase(pos, search.length());
+        s.insert(pos, replace);
+    }
+}
+
+static bool is_wav_buffer(const std::string buf) {
+    // RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
+    // WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
+    if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) != "WAVE") {
+        return false;
+    }
+
+    uint32_t chunk_size = *reinterpret_cast<const uint32_t*>(buf.data() + 4);
+    if (chunk_size + 8 != buf.size()) {
+        return false;
+    }
+
+    return true;
+}
+
+static bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
+    drwav wav;
+    std::vector<uint8_t> wav_data; // used for pipe input from stdin or ffmpeg decoding output
+
+    if (fname == "-") {
+        // {
+        //     #ifdef _WIN32
+        //     _setmode(_fileno(stdin), _O_BINARY);
+        //     #endif
+
+        //     uint8_t buf[1024];
+        //     while (true)
+        //     {
+        //         const size_t n = fread(buf, 1, sizeof(buf), stdin);
+        //         if (n == 0) {
+        //             break;
+        //         }
+        //         wav_data.insert(wav_data.end(), buf, buf + n);
+        //     }
+        // }
+
+        if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
+            fprintf(stderr, "error: failed to open WAV file from stdin\n");
+            return false;
+        }
+
+        fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
+    }
+    else if (is_wav_buffer(fname)) {
+        if (drwav_init_memory(&wav, fname.c_str(), fname.size(), nullptr) == false) {
+            fprintf(stderr, "error: failed to open WAV file from fname buffer\n");
+            return false;
+        }
+    }
+    else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
+#if defined(WHISPER_FFMPEG)
+        if (ffmpeg_decode_audio(fname, wav_data) != 0) {
+            fprintf(stderr, "error: failed to ffmpeg decode '%s' \n", fname.c_str());
+            return false;
+        }
+        if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
+            fprintf(stderr, "error: failed to read wav data as wav \n");
+            return false;
+        }
+#else
+        fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
+        return false;
+#endif
+    }
+
+    if (wav.channels != 1 && wav.channels != 2) {
+        fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
+        drwav_uninit(&wav);
+        return false;
+    }
+
+    if (stereo && wav.channels != 2) {
+        fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
+        drwav_uninit(&wav);
+        return false;
+    }
+
+    if (wav.sampleRate != COMMON_SAMPLE_RATE) {
+        fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
+        drwav_uninit(&wav);
+        return false;
+    }
+
+    if (wav.bitsPerSample != 16) {
+        fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
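+        // note: dr_wav could convert other bit depths on the fly (drwav_read_pcm_frames_s16() converts from the
+        // container's native format), but this loader accepts only 16 kHz, 16-bit, mono/stereo PCM, so inputs
+        // should be converted up front (e.g. with ffmpeg) unless WHISPER_FFMPEG decoding is compiled in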
+        drwav_uninit(&wav);
+        return false;
+    }
+
+    const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
+
+    std::vector<int16_t> pcm16;
+    pcm16.resize(n*wav.channels);
+    drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
+    drwav_uninit(&wav);
+
+    // convert to mono, float
+    pcmf32.resize(n);
+    if (wav.channels == 1) {
+        for (uint64_t i = 0; i < n; i++) {
+            pcmf32[i] = float(pcm16[i])/32768.0f;
+        }
+    } else {
+        for (uint64_t i = 0; i < n; i++) {
+            pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+        }
+    }
+
+    if (stereo) {
+        // convert to stereo, float
+        pcmf32s.resize(2);
+
+        pcmf32s[0].resize(n);
+        pcmf32s[1].resize(n);
+        for (uint64_t i = 0; i < n; i++) {
+            pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
+            pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
+        }
+    }
+
+    return true;
+}
+
+// command-line parameters
+struct whisper_params {
+    int32_t n_threads     = std::min(4, (int32_t) std::thread::hardware_concurrency());
+    int32_t n_processors  = 1;
+    int32_t offset_t_ms   = 0;
+    int32_t offset_n      = 0;
+    int32_t duration_ms   = 0;
+    int32_t progress_step = 5;
+    int32_t max_context   = -1;
+    int32_t max_len       = 0;
+    int32_t best_of       = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size     = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
+    int32_t audio_ctx     = 0;
+
+    float word_thold      =  0.01f;
+    float entropy_thold   =  2.40f;
+    float logprob_thold   = -1.00f;
+    float grammar_penalty = 100.0f;
+    float temperature     = 0.0f;
+    float temperature_inc = 0.2f;
+
+    bool speed_up        = false;
+    bool debug_mode      = false;
+    bool translate       = false;
+    bool detect_language = false;
+    bool diarize         = false;
+    bool tinydiarize     = false;
+    bool split_on_word   = false;
+    bool no_fallback     = false;
+    bool output_txt      = false;
+    bool output_vtt      = false;
+    bool output_srt      = false;
+    bool output_wts      = false;
+    bool output_csv      = false;
+    bool output_jsn      = false;
+    bool output_jsn_full = false;
+    bool output_lrc      = false;
+    bool no_prints       = false;
+    bool print_special   = false;
+    bool print_colors    = false;
+    bool print_progress  = false;
+    bool no_timestamps   = false;
+    bool log_score       = false;
+    bool use_gpu         = true;
+    bool flash_attn      = false;
+
+    std::string language  = "en";
+    std::string prompt;
+    std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
+    std::string model     = "models/ggml-base.en.bin";
+    std::string grammar;
+    std::string grammar_rule;
+
+    // [TDRZ] speaker turn string
+    std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
+
+    // A regular expression that matches tokens to suppress
+    std::string suppress_regex;
+
+    std::string openvino_encode_device = "CPU";
+
+    std::string dtw = "";
+
+    std::vector<std::string> fname_inp = {};
+    std::vector<std::string> fname_out = {};
+
+    grammar_parser::parse_state grammar_parsed;
+};
+
+void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
+
+char* whisper_param_turn_lowercase(char* in){
+    int string_len = strlen(in);
+    for (int i = 0; i < string_len; i++){
+        *(in+i) = tolower((unsigned char)*(in+i));
+    }
+    return in;
+}
+
+bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
+    for (int i = 1; i < argc; i++) {
+        std::string arg = argv[i];
+
+        if (arg == "-"){
+            params.fname_inp.push_back(arg);
+            continue;
+        }
+
+        if (arg[0] != '-') {
+            params.fname_inp.push_back(arg);
+            continue;
+        }
+
+        if (arg == "-h" || arg == "--help") {
+            whisper_print_usage(argc, argv, params);
+            exit(0);
+        }
+        else if (arg == "-t"    ||
arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); } + else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); } + // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; } + else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; } + else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; } + else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; } + else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; } + else if (arg == "-ojf" || arg == "--output-json-full"){ params.output_jsn_full = params.output_jsn = true; } + else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); } + else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = whisper_param_turn_lowercase(argv[++i]); } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-m" 
|| arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } + else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; } + else if ( arg == "--grammar") { params.grammar = argv[++i]; } + else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; } + else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + +void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0.wav file1.wav ...\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); + // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? 
"true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false"); + fprintf(stderr, " -osrt, --output-srt [%-7s] output result in a srt file\n", params.output_srt ? "true" : "false"); + fprintf(stderr, " -olrc, --output-lrc [%-7s] output result in a lrc file\n", params.output_lrc ? "true" : "false"); + fprintf(stderr, " -owts, --output-words [%-7s] output script for generating karaoke video\n", params.output_wts ? "true" : "false"); + fprintf(stderr, " -fp, --font-path [%-7s] path to a monospace font for karaoke video\n", params.font_path.c_str()); + fprintf(stderr, " -ocsv, --output-csv [%-7s] output result in a CSV file\n", params.output_csv ? "true" : "false"); + fprintf(stderr, " -oj, --output-json [%-7s] output result in a JSON file\n", params.output_jsn ? "true" : "false"); + fprintf(stderr, " -ojf, --output-json-full [%-7s] include more information in the JSON file\n", params.output_jsn_full ? "true" : "false"); + fprintf(stderr, " -of FNAME, --output-file FNAME [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.prompt.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] input WAV file path\n", ""); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); + fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? 
"true" : "false"); + fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str()); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); + fprintf(stderr, "\n"); +} + +struct whisper_print_user_data { + const whisper_params * params; + + const std::vector> * pcmf32s; + int progress_prev; +}; + +std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { + std::string speaker = ""; + const int64_t n_samples = pcmf32s[0].size(); + + const int64_t is0 = timestamp_to_sample(t0, n_samples, WHISPER_SAMPLE_RATE); + const int64_t is1 = timestamp_to_sample(t1, n_samples, WHISPER_SAMPLE_RATE); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for (int64_t j = is0; j < is1; j++) { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } + + if (energy0 > 1.1*energy1) { + speaker = "0"; + } else if (energy1 > 1.1*energy0) { + speaker = "1"; + } else { + speaker = "?"; + } + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str()); + + if (!id_only) { + speaker.insert(0, "(speaker "); + speaker.append(")"); + } + + return speaker; +} +void whisper_print_progress_callback(struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) { + int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step; + int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev); + if (progress >= *progress_prev + progress_step) { + *progress_prev += progress_step; + fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress); + } +} + +void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { + const auto & params = *((whisper_print_user_data *) user_data)->params; + const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; + + const int n_segments = whisper_full_n_segments(ctx); + + std::string speaker = ""; + + int64_t t0 = 0; + int64_t t1 = 0; + + // print the last n_new segments + const int s0 = n_segments - n_new; + + if (s0 == 0) { + printf("\n"); + } + + for (int i = s0; i < n_segments; i++) { + if (!params.no_timestamps || params.diarize) { + t0 = whisper_full_get_segment_t0(ctx, i); + t1 = whisper_full_get_segment_t1(ctx, i); + } + + if (!params.no_timestamps) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + } + + if (params.diarize && pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + if (params.print_colors) { + // for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + // if (params.print_special == false) { + // const whisper_token id = whisper_full_get_token_id(ctx, i, j); + // if (id >= whisper_token_eot(ctx)) { + // continue; + // } + // } + + // const char * text = whisper_full_get_token_text(ctx, i, j); + // const float p = whisper_full_get_token_p (ctx, i, j); + + // const int col = std::max(0, std::min((int) k_colors.size() - 1, (int) (std::pow(p, 3)*float(k_colors.size())))); + + // printf("%s%s%s%s", speaker.c_str(), k_colors[col].c_str(), text, "\033[0m"); + // } + } else { 
+ const char * text = whisper_full_get_segment_text(ctx, i); + + printf("%s%s", speaker.c_str(), text); + } + + if (params.tinydiarize) { + if (whisper_full_get_segment_speaker_turn_next(ctx, i)) { + printf("%s", params.tdrz_speaker_turn.c_str()); + } + } + + // with timestamps or speakers: each segment on new line + if (!params.no_timestamps || params.diarize) { + printf("\n"); + } + + fflush(stdout); + } +} + +bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + fout << speaker << text << "\n"; + } + + return true; +} + +bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + fout << "WEBVTT\n\n"; + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true); + speaker.insert(0, ""); + } + + fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; + fout << speaker << text << "\n\n"; + } + + return true; +} + +bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + fout << i + 1 + params.offset_n << "\n"; + fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; + fout << speaker << text << "\n\n"; + } + + return true; +} + +char *escape_double_quotes_and_backslashes(const char *str) { + if (str == NULL) { + return NULL; + } + + size_t escaped_length = strlen(str) + 1; + + for (size_t i = 0; str[i] != '\0'; i++) { + if (str[i] == '"' || str[i] == '\\') { + escaped_length++; + } + } + + char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed + if (escaped == NULL) { + return NULL; + } + + size_t pos = 0; + for (size_t i = 0; str[i] 
!= '\0'; i++) { + if (str[i] == '"' || str[i] == '\\') { + escaped[pos++] = '\\'; + } + escaped[pos++] = str[i]; + } + + // no need to set zero due to calloc() being used prior + + return escaped; +} + +// double quote should be escaped by another double quote. (rfc4180) +char *escape_double_quotes_in_csv(const char *str) { + if (str == NULL) { + return NULL; + } + + size_t escaped_length = strlen(str) + 1; + + for (size_t i = 0; str[i] != '\0'; i++) { + if (str[i] == '"') { + escaped_length++; + } + } + + char *escaped = (char *)calloc(escaped_length, 1); // pre-zeroed + if (escaped == NULL) { + return NULL; + } + + size_t pos = 0; + for (size_t i = 0; str[i] != '\0'; i++) { + if (str[i] == '"') { + escaped[pos++] = '"'; + } + escaped[pos++] = str[i]; + } + + // no need to set zero due to calloc() being used prior + + return escaped; +} + +bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + fout << "start,end,"; + if (params.diarize && pcmf32s.size() == 2) + { + fout << "speaker,"; + } + fout << "text\n"; + + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + char * text_escaped = escape_double_quotes_in_csv(text); + + //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. + fout << 10 * t0 << "," << 10 * t1 << ","; + if (params.diarize && pcmf32s.size() == 2) + { + fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ","; + } + fout << "\"" << text_escaped << "\"\n"; + } + + return true; +} + +bool output_score(struct whisper_context * ctx, const char * fname, const whisper_params & /*params*/, std::vector> /*pcmf32s*/) { + std::ofstream fout(fname); + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + const int n_segments = whisper_full_n_segments(ctx); + // fprintf(stderr,"segments: %d\n",n_segments); + for (int i = 0; i < n_segments; ++i) { + const int n_tokens = whisper_full_n_tokens(ctx, i); + // fprintf(stderr,"tokens: %d\n",n_tokens); + for (int j = 0; j < n_tokens; j++) { + auto token = whisper_full_get_token_text(ctx, i, j); + auto probability = whisper_full_get_token_p(ctx, i, j); + fout << token << '\t' << probability << std::endl; + // fprintf(stderr,"token: %s %f\n",token,probability); + } + } + return true; +} + +bool output_json( + struct whisper_context * ctx, + const char * fname, + const whisper_params & params, + std::vector> pcmf32s, + bool full) { + std::ofstream fout(fname); + int indent = 0; + + auto doindent = [&]() { + for (int i = 0; i < indent; i++) fout << "\t"; + }; + + auto start_arr = [&](const char *name) { + doindent(); + fout << "\"" << name << "\": [\n"; + indent++; + }; + + auto end_arr = [&](bool end) { + indent--; + doindent(); + fout << (end ? "]\n" : "],\n"); + }; + + auto start_obj = [&](const char *name) { + doindent(); + if (name) { + fout << "\"" << name << "\": {\n"; + } else { + fout << "{\n"; + } + indent++; + }; + + auto end_obj = [&](bool end) { + indent--; + doindent(); + fout << (end ? 
"}\n" : "},\n"); + }; + + auto start_value = [&](const char *name) { + doindent(); + fout << "\"" << name << "\": "; + }; + + auto value_s = [&](const char *name, const char *val, bool end) { + start_value(name); + char * val_escaped = escape_double_quotes_and_backslashes(val); + fout << "\"" << val_escaped << (end ? "\"\n" : "\",\n"); + free(val_escaped); + }; + + auto end_value = [&](bool end) { + fout << (end ? "\n" : ",\n"); + }; + + auto value_i = [&](const char *name, const int64_t val, bool end) { + start_value(name); + fout << val; + end_value(end); + }; + + auto value_f = [&](const char *name, const float val, bool end) { + start_value(name); + fout << val; + end_value(end); + }; + + auto value_b = [&](const char *name, const bool val, bool end) { + start_value(name); + fout << (val ? "true" : "false"); + end_value(end); + }; + + auto times_o = [&](int64_t t0, int64_t t1, bool end) { + start_obj("timestamps"); + value_s("from", to_timestamp(t0, true).c_str(), false); + value_s("to", to_timestamp(t1, true).c_str(), true); + end_obj(false); + start_obj("offsets"); + value_i("from", t0 * 10, false); + value_i("to", t1 * 10, true); + end_obj(end); + }; + + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + start_obj(nullptr); + value_s("systeminfo", whisper_print_system_info(), false); + start_obj("model"); + value_s("type", whisper_model_type_readable(ctx), false); + value_b("multilingual", whisper_is_multilingual(ctx), false); + value_i("vocab", whisper_model_n_vocab(ctx), false); + start_obj("audio"); + value_i("ctx", whisper_model_n_audio_ctx(ctx), false); + value_i("state", whisper_model_n_audio_state(ctx), false); + value_i("head", whisper_model_n_audio_head(ctx), false); + value_i("layer", whisper_model_n_audio_layer(ctx), true); + end_obj(false); + start_obj("text"); + value_i("ctx", whisper_model_n_text_ctx(ctx), false); + value_i("state", whisper_model_n_text_state(ctx), false); + value_i("head", whisper_model_n_text_head(ctx), false); + value_i("layer", whisper_model_n_text_layer(ctx), true); + end_obj(false); + value_i("mels", whisper_model_n_mels(ctx), false); + value_i("ftype", whisper_model_ftype(ctx), true); + end_obj(false); + start_obj("params"); + value_s("model", params.model.c_str(), false); + value_s("language", params.language.c_str(), false); + value_b("translate", params.translate, true); + end_obj(false); + start_obj("result"); + value_s("language", whisper_lang_str(whisper_full_lang_id(ctx)), true); + end_obj(false); + start_arr("transcription"); + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + start_obj(nullptr); + times_o(t0, t1, false); + value_s("text", text, !params.diarize && !params.tinydiarize && !full); + + if (full) { + start_arr("tokens"); + const int n = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n; ++j) { + auto token = whisper_full_get_token_data(ctx, i, j); + start_obj(nullptr); + value_s("text", whisper_token_to_str(ctx, token.id), false); + if(token.t0 > -1 && token.t1 > -1) { + // If we have per-token timestamps, write them out + times_o(token.t0, token.t1, false); + } + value_i("id", token.id, false); + value_f("p", token.p, false); + value_f("t_dtw", token.t_dtw, 
true); + end_obj(j == (n - 1)); + } + end_arr(!params.diarize && !params.tinydiarize); + } + + if (params.diarize && pcmf32s.size() == 2) { + value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true); + } + + if (params.tinydiarize) { + value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true); + } + end_obj(i == (n_segments - 1)); + } + + end_arr(true); + end_obj(true); + return true; +} + +// karaoke video generation +// outputs a bash script that uses ffmpeg to generate a video with the subtitles +// TODO: font parameter adjustments +bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector> pcmf32s) { + std::ofstream fout(fname); + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + static const char * font = params.font_path.c_str(); + + std::ifstream fin(font); + if (!fin.is_open()) { + fprintf(stderr, "%s: font not found at '%s', please specify a monospace font with -fp\n", __func__, font); + return false; + } + + fout << "#!/bin/bash" << "\n"; + fout << "\n"; + + fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \""; + + for (int i = 0; i < whisper_full_n_segments(ctx); i++) { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + + const int n = whisper_full_n_tokens(ctx, i); + + std::vector tokens(n); + for (int j = 0; j < n; ++j) { + tokens[j] = whisper_full_get_token_data(ctx, i, j); + } + + if (i > 0) { + fout << ","; + } + + // background text + fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; + + bool is_first = true; + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + for (int j = 0; j < n; ++j) { + const auto & token = tokens[j]; + + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + std::string txt_bg = ""; + std::string txt_fg = ""; // highlight token + std::string txt_ul = ""; // underline + + if (params.diarize && pcmf32s.size() == 2) { + txt_bg = speaker; + txt_fg = speaker; + txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ "; + } + + txt_bg.append("> "); + txt_fg.append("> "); + txt_ul.append("\\ \\ "); + + { + for (int k = 0; k < n; ++k) { + const auto & token2 = tokens[k]; + + if (tokens[k].id >= whisper_token_eot(ctx)) { + continue; + } + + const std::string txt = whisper_token_to_str(ctx, token2.id); + + txt_bg += txt; + + if (k == j) { + for (int l = 0; l < (int) txt.size(); ++l) { + txt_fg += txt[l]; + txt_ul += "_"; + } + txt_fg += "|"; + } else { + for (int l = 0; l < (int) txt.size(); ++l) { + txt_fg += "\\ "; + txt_ul += "\\ "; + } + } + } + + ::l_replace_all(txt_bg, "'", "\u2019"); + ::l_replace_all(txt_bg, "\"", "\\\""); + ::l_replace_all(txt_fg, "'", "\u2019"); + ::l_replace_all(txt_fg, "\"", "\\\""); + } + + if (is_first) { + // background text + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'"; + is_first = false; + } + + // foreground text + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 
<< ")'"; + + // underline + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2+16:text='" << txt_ul << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; + } + } + + fout << "\" -c:v libx264 -pix_fmt yuv420p -y " << fname_inp << ".mp4" << "\n"; + + fout << "\n\n"; + fout << "echo \"Your video has been saved to " << fname_inp << ".mp4\"" << "\n"; + fout << "\n"; + fout << "echo \" ffplay " << fname_inp << ".mp4\"\n"; + fout << "\n"; + + fout.close(); + + fprintf(stderr, "%s: run 'source %s' to generate karaoke video\n", __func__, fname); + + return true; +} + +bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { + std::ofstream fout(fname); + if (!fout.is_open()) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); + return false; + } + + fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + + fout << "[by:whisper.cpp]\n"; + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + const int64_t t = whisper_full_get_segment_t0(ctx, i); + + int64_t msec = t * 10; + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[16]; + snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10)); + std::string timestamp_lrc = std::string(buf); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + fout << '[' << timestamp_lrc << ']' << speaker << text << "\n"; + } + + return true; +} + + +void cb_log_disable(enum ggml_log_level , const char * , void * ) { } + +int main(int argc, char ** argv) { + whisper_params params; + + // If the only argument starts with "@", read arguments line-by-line + // from the given file. + std::vector vec_args; + if (argc == 2 && argv != nullptr && argv[1] != nullptr && argv[1][0] == '@') { + // Save the name of the executable. + vec_args.push_back(argv[0]); + + // Open the response file. + char const * rspfile = argv[1] + sizeof(char); + std::ifstream fin(rspfile); + if (fin.is_open() == false) { + fprintf(stderr, "error: response file '%s' not found\n", rspfile); + return 1; + } + + // Read the entire response file. + std::string line; + while (std::getline(fin, line)) { + vec_args.push_back(line); + } + + // Use the contents of the response file as the command-line arguments. 
+        argc = static_cast<int>(vec_args.size());
+        argv = static_cast<char **>(alloca(argc * sizeof (char *)));
+        for (int i = 0; i < argc; ++i) {
+            argv[i] = const_cast<char *>(vec_args[i].c_str());
+        }
+    }
+
+    if (whisper_params_parse(argc, argv, params) == false) {
+        whisper_print_usage(argc, argv, params);
+        return 1;
+    }
+
+    // remove non-existent files
+    for (auto it = params.fname_inp.begin(); it != params.fname_inp.end();) {
+        const auto fname_inp = it->c_str();
+
+        if (*it != "-" && !is_file_exist(fname_inp)) {
+            fprintf(stderr, "error: input file not found '%s'\n", fname_inp);
+            it = params.fname_inp.erase(it);
+            continue;
+        }
+
+        it++;
+    }
+
+    if (params.fname_inp.empty()) {
+        fprintf(stderr, "error: no input files specified\n");
+        whisper_print_usage(argc, argv, params);
+        return 2;
+    }
+
+    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
+        fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+        whisper_print_usage(argc, argv, params);
+        exit(0);
+    }
+
+    if (params.diarize && params.tinydiarize) {
+        fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
+        whisper_print_usage(argc, argv, params);
+        exit(0);
+    }
+
+    if (params.no_prints) {
+        whisper_log_set(cb_log_disable, NULL);
+    }
+
+    // whisper init
+
+    struct whisper_context_params cparams = whisper_context_default_params();
+
+    cparams.use_gpu    = params.use_gpu;
+    cparams.flash_attn = params.flash_attn;
+
+    if (!params.dtw.empty()) {
+        cparams.dtw_token_timestamps = true;
+        cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
+
+        if (params.dtw == "tiny")      cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY;
+        if (params.dtw == "tiny.en")   cparams.dtw_aheads_preset = WHISPER_AHEADS_TINY_EN;
+        if (params.dtw == "base")      cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE;
+        if (params.dtw == "base.en")   cparams.dtw_aheads_preset = WHISPER_AHEADS_BASE_EN;
+        if (params.dtw == "small")     cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL;
+        if (params.dtw == "small.en")  cparams.dtw_aheads_preset = WHISPER_AHEADS_SMALL_EN;
+        if (params.dtw == "medium")    cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM;
+        if (params.dtw == "medium.en") cparams.dtw_aheads_preset = WHISPER_AHEADS_MEDIUM_EN;
+        if (params.dtw == "large.v1")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V1;
+        if (params.dtw == "large.v2")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V2;
+        if (params.dtw == "large.v3")  cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3;
+
+        if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) {
+            fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str());
+            return 3;
+        }
+    }
+
+    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
+
+    if (ctx == nullptr) {
+        fprintf(stderr, "error: failed to initialize whisper context\n");
+        return 3;
+    }
+
+    // initialize openvino encoder.
this has no effect on whisper.cpp builds that don't have OpenVINO configured + whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + + if (!params.grammar.empty()) { + auto & grammar = params.grammar_parsed; + if (is_file_exist(params.grammar.c_str())) { + // read grammar from file + std::ifstream ifs(params.grammar.c_str()); + const std::string txt = std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); + grammar = grammar_parser::parse(txt.c_str()); + } else { + // read grammar from string + grammar = grammar_parser::parse(params.grammar.c_str()); + } + + // will be empty (default) if there are parse errors + if (grammar.rules.empty()) { + fprintf(stderr, "error: failed to parse grammar \"%s\"\n", params.grammar.c_str()); + return 4; + } else { + fprintf(stderr, "%s: grammar:\n", __func__); + grammar_parser::print_grammar(stderr, grammar); + fprintf(stderr, "\n"); + } + } + + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { + const auto fname_inp = params.fname_inp[f]; + const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; + + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + + if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { + fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); + continue; + } + + if (!whisper_is_multilingual(ctx)) { + if (params.language != "en" || params.translate) { + params.language = "en"; + params.translate = false; + fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); + } + } + if (params.detect_language) { + params.language = "auto"; + } + + if (!params.no_prints) { + // print system information + fprintf(stderr, "\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info()); + + // print some info about the processing + fprintf(stderr, "\n"); + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n", + __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, + params.n_threads, params.n_processors, params.beam_size, params.best_of, + params.language.c_str(), + params.translate ? "translate" : "transcribe", + params.tinydiarize ? "tdrz = 1, " : "", + params.no_timestamps ? 0 : 1); + + fprintf(stderr, "\n"); + } + + // run the inference + { + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + + const bool use_grammar = (!params.grammar_parsed.rules.empty() && !params.grammar_rule.empty()); + wparams.strategy = (params.beam_size > 1 || use_grammar) ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY; + + wparams.print_realtime = false; + wparams.print_progress = params.print_progress; + wparams.print_timestamps = !params.no_timestamps; + wparams.print_special = params.print_special; + wparams.translate = params.translate; + wparams.language = params.language.c_str(); + wparams.detect_language = params.detect_language; + wparams.n_threads = params.n_threads; + wparams.n_max_text_ctx = params.max_context >= 0 ? 
params.max_context : wparams.n_max_text_ctx;
+            wparams.offset_ms        = params.offset_t_ms;
+            wparams.duration_ms      = params.duration_ms;
+
+            wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0;
+            wparams.thold_pt         = params.word_thold;
+            wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+            wparams.split_on_word    = params.split_on_word;
+            wparams.audio_ctx        = params.audio_ctx;
+
+            wparams.speed_up         = params.speed_up;
+            wparams.debug_mode       = params.debug_mode;
+
+            wparams.tdrz_enable      = params.tinydiarize; // [TDRZ]
+
+            wparams.suppress_regex   = params.suppress_regex.empty() ? nullptr : params.suppress_regex.c_str();
+
+            wparams.initial_prompt   = params.prompt.c_str();
+
+            wparams.greedy.best_of        = params.best_of;
+            wparams.beam_search.beam_size = params.beam_size;
+
+            wparams.temperature_inc  = params.no_fallback ? 0.0f : params.temperature_inc;
+            wparams.temperature      = params.temperature;
+
+            wparams.entropy_thold    = params.entropy_thold;
+            wparams.logprob_thold    = params.logprob_thold;
+
+            wparams.no_timestamps    = params.no_timestamps;
+
+            whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
+
+            const auto & grammar_parsed = params.grammar_parsed;
+            // auto grammar_rules = grammar_parsed.c_rules();
+
+            // if (use_grammar) {
+            //     if (grammar_parsed.symbol_ids.find(params.grammar_rule) == grammar_parsed.symbol_ids.end()) {
+            //         fprintf(stderr, "%s: warning: grammar rule '%s' not found - skipping grammar sampling\n", __func__, params.grammar_rule.c_str());
+            //     } else {
+            //         wparams.grammar_rules = grammar_rules.data();
+            //         wparams.n_grammar_rules = grammar_rules.size();
+            //         wparams.i_start_rule = grammar_parsed.symbol_ids.at(params.grammar_rule);
+            //         wparams.grammar_penalty = params.grammar_penalty;
+            //     }
+            // }
+
+            // this callback is called on each new segment
+            if (!wparams.print_realtime) {
+                wparams.new_segment_callback           = whisper_print_segment_callback;
+                wparams.new_segment_callback_user_data = &user_data;
+            }
+
+            if (wparams.print_progress) {
+                wparams.progress_callback           = whisper_print_progress_callback;
+                wparams.progress_callback_user_data = &user_data;
+            }
+
+            // examples for abort mechanism
+            // in examples below, we do not abort the processing, but we could if the flag is set to true
+
+            // the callback is called before every encoder run - if it returns false, the processing is aborted
+            {
+                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
+                    bool is_aborted = *(bool*)user_data;
+                    return !is_aborted;
+                };
+                wparams.encoder_begin_callback_user_data = &is_aborted;
+            }
+
+            // the callback is called before every computation - if it returns true, the computation is aborted
+            {
+                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
+
+                wparams.abort_callback = [](void * user_data) {
+                    bool is_aborted = *(bool*)user_data;
+                    return is_aborted;
+                };
+                wparams.abort_callback_user_data = &is_aborted;
+            }
+
+            if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
+                fprintf(stderr, "%s: failed to process audio\n", argv[0]);
+                return 10;
+            }
+        }
+
+        // output stuff
+        {
+            printf("\n");
+
+            // output to text file
+            if (params.output_txt) {
+                const auto fname_txt = fname_out + ".txt";
+                output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
+            }
+
+            // output to VTT file
+            if (params.output_vtt) {
+                const auto fname_vtt = fname_out + ".vtt";
+                output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
+            }
+
+            // output to SRT file
+            if (params.output_srt) {
+                const auto fname_srt = fname_out + ".srt";
+                output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
+            }
+
+            // output to WTS file
+            if (params.output_wts) {
+                const auto fname_wts = fname_out + ".wts";
+                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
+            }
+
+            // output to CSV file
+            if (params.output_csv) {
+                const auto fname_csv = fname_out + ".csv";
+                output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
+            }
+
+            // output to JSON file
+            if (params.output_jsn) {
+                const auto fname_jsn = fname_out + ".json";
+                output_json(ctx, fname_jsn.c_str(), params, pcmf32s, params.output_jsn_full);
+            }
+
+            // output to LRC file
+            if (params.output_lrc) {
+                const auto fname_lrc = fname_out + ".lrc";
+                output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
+            }
+
+            // output to score file
+            if (params.log_score) {
+                const auto fname_score = fname_out + ".score.txt";
+                output_score(ctx, fname_score.c_str(), params, pcmf32s);
+            }
+        }
+    }
+
+    if (!params.no_prints) {
+        whisper_print_timings(ctx);
+    }
+    whisper_free(ctx);
+
+    return 0;
+}
diff --git a/otherarch/whispercpp/whisper.cpp b/otherarch/whispercpp/whisper.cpp
new file mode 100644
index 000000000..d125f8e9c
--- /dev/null
+++ b/otherarch/whispercpp/whisper.cpp
@@ -0,0 +1,7346 @@
+#include "whisper.h"
+
+#ifdef WHISPER_USE_COREML
+#include "coreml/whisper-encoder.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
+
+#ifdef WHISPER_USE_OPENVINO
+#include "openvino/whisper-openvino-encoder.h"
+#endif
+
+#include "ggml.h"
+#include "ggml-alloc.h"
+#include "ggml-backend.h"
+
+#include <atomic>
+#include <algorithm>
+#include <cassert>
+#define _USE_MATH_DEFINES
+#include <cmath>
+#include <cstdio>
+#include <cstdarg>
+#include <cstdint>
+#include <cstring>
+#include <fstream>
+#include <map>
+#include <set>
+#include <string>
+#include <thread>
+#include <vector>
+#include <regex>
+#include <random>
+#include <functional>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+#if defined(GGML_BIG_ENDIAN)
+#include <bit>
+
+template<typename T>
+static T byteswap(T value) {
+    return std::byteswap(value);
+}
+
+template<>
+float byteswap(float value) {
+    return std::bit_cast<float>(byteswap(std::bit_cast<std::uint32_t>(value)));
+}
+
+template<typename T>
+static void byteswap_tensor_data(ggml_tensor * tensor) {
+    T * datum = reinterpret_cast<T *>(tensor->data);
+    for (int i = 0; i < ggml_nelements(tensor); i++) {
+        datum[i] = byteswap(datum[i]);
+    }
+}
+
+static void byteswap_tensor(ggml_tensor * tensor) {
+    switch (tensor->type) {
+        case GGML_TYPE_I16: {
+            byteswap_tensor_data<int16_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_F16: {
+            byteswap_tensor_data<ggml_fp16_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_I32: {
+            byteswap_tensor_data<int32_t>(tensor);
+            break;
+        }
+        case GGML_TYPE_F32: {
+            byteswap_tensor_data<float>(tensor);
+            break;
+        }
+        default: { // GGML_TYPE_I8
+            break;
+        }
+    }
+}
+
+#define BYTESWAP_VALUE(d) d = byteswap(d)
+#define BYTESWAP_FILTERS(f)           \
+    do {                              \
+        for (auto & datum : f.data) { \
+            datum = byteswap(datum);  \
+        }                             \
+    } while (0)
+#define BYTESWAP_TENSOR(t)  \
+    do {                    \
+        byteswap_tensor(t); \
+    } while (0)
+#else
+#define BYTESWAP_VALUE(d) do {} while (0)
+#define BYTESWAP_FILTERS(f) do {} while (0)
+#define BYTESWAP_TENSOR(t) do {} while (0)
+#endif
+
+#ifdef __GNUC__
+#ifdef __MINGW32__
+#define WHISPER_ATTRIBUTE_FORMAT(...)
+ +#define WHISPER_ASSERT(x) \ + do { \ + if (!(x)) { \ + WHISPER_LOG_ERROR("WHISPER_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +//#define WHISPER_USE_FLASH_FF +#define WHISPER_MAX_DECODERS 8 +#define WHISPER_MAX_NODES 4096 + +// +// ggml helpers +// + +static bool ggml_graph_compute_helper( + struct ggml_cgraph * graph, + std::vector<uint8_t> & buf, + int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + + plan.abort_callback = abort_callback; + plan.abort_callback_data = abort_callback_data; + + if (plan.work_size > 0) { + buf.resize(plan.work_size); + plan.work_data = buf.data(); + } + + return ggml_graph_compute(graph, &plan); +} + +static bool ggml_graph_compute_helper( + struct ggml_backend * backend, + struct ggml_cgraph * graph, + int n_threads) { + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + return ggml_backend_graph_compute(backend, graph) == GGML_STATUS_SUCCESS; +} + +// faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad" +// the idea is to represent the original matrix multiplication: +// +// Z = X @ Y +// +// with the sum of two matrix multiplications: +// +// Z = (X_0 @ Y_0) + (X_1 @ Y_1) +// +// here X_0 and Y_0 are views of X and Y that have dimension 0 divisible by "pad" +// and X_1 and Y_1 are the remaining views. X_1 and Y_1 end up being small matrices that can be processed with more +// general-purpose kernels +// +static struct ggml_tensor * ggml_mul_mat_pad(struct ggml_context * ctx, struct ggml_tensor * x, struct ggml_tensor * y, int pad = 32) { + // use padding only if dimension 0 is at least 8 times larger than the padding + // else we won't get much benefit from the optimization + const int n_pad_req = 8; + + if (x->ne[0] % pad == 0 || x->ne[0] / pad < n_pad_req) { + return ggml_mul_mat(ctx, x, y); + } + + struct ggml_tensor * x_0 = ggml_view_3d(ctx, x, (x->ne[0]/pad)*pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], 0); + struct ggml_tensor * x_1 = ggml_view_3d(ctx, x, x->ne[0]%pad, x->ne[1], x->ne[2], x->nb[1], x->nb[2], x_0->ne[0]*x_0->nb[0]); + + struct ggml_tensor * y_0 = ggml_view_3d(ctx, y, (y->ne[0]/pad)*pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], 0); + struct ggml_tensor * y_1 = ggml_view_3d(ctx, y, y->ne[0]%pad, y->ne[1], y->ne[2], y->nb[1], y->nb[2], y_0->ne[0]*y_0->nb[0]); + + return ggml_add(ctx, + ggml_mul_mat(ctx, x_0, y_0), + ggml_mul_mat(ctx, x_1, y_1)); +} + +// TODO: check if other platforms can benefit from this optimization +// TODO: CUDA is currently broken - seems ggml_mul_mat does not handle views correctly +#if defined(GGML_USE_METAL) +#define ggml_mul_mat ggml_mul_mat_pad +#endif
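Worth spelling out what the split buys: dimension 0 is the contraction axis for ggml_mul_mat, so carving both operands at the same offset and summing the two partial products is mathematically exact. A worked instance under the pad/threshold defaults above:

    // x->ne[0] == 1500, pad == 32: 1500 / 32 == 46 >= 8, and 1500 % 32 == 28, so the split kicks in
    //   x_0 / y_0 cover the first 1472 elements of dim 0 (32-aligned -> fast kernel)
    //   x_1 / y_1 cover the remaining 28 elements (small, handled by the generic kernel)
    //   ggml_add(ctx, ggml_mul_mat(ctx, x_0, y_0), ggml_mul_mat(ctx, x_1, y_1)) == ggml_mul_mat(ctx, x, y)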
+ +// available whisper models +enum e_model { + MODEL_UNKNOWN, + MODEL_TINY, + MODEL_BASE, + MODEL_SMALL, + MODEL_MEDIUM, + MODEL_LARGE, +}; + +static const std::map<e_model, std::string> g_model_name = { + { MODEL_UNKNOWN, "unknown" }, + { MODEL_TINY, "tiny" }, + { MODEL_BASE, "base" }, + { MODEL_SMALL, "small" }, + { MODEL_MEDIUM, "medium" }, + { MODEL_LARGE, "large" }, +}; + +static const std::map<std::string, std::pair<int, std::string>> g_lang = { + { "en", { 0, "english", } }, + { "zh", { 1, "chinese", } }, + { "de", { 2, "german", } }, + { "es", { 3, "spanish", } }, + { "ru", { 4, "russian", } }, + { "ko", { 5, "korean", } }, + { "fr", { 6, "french", } }, + { "ja", { 7, "japanese", } }, + { "pt", { 8, "portuguese", } }, + { "tr", { 9, "turkish", } }, + { "pl", { 10, "polish", } }, + { "ca", { 11, "catalan", } }, + { "nl", { 12, "dutch", } }, + { "ar", { 13, "arabic", } }, + { "sv", { 14, "swedish", } }, + { "it", { 15, "italian", } }, + { "id", { 16, "indonesian", } }, + { "hi", { 17, "hindi", } }, + { "fi", { 18, "finnish", } }, + { "vi", { 19, "vietnamese", } }, + { "he", { 20, "hebrew", } }, + { "uk", { 21, "ukrainian", } }, + { "el", { 22, "greek", } }, + { "ms", { 23, "malay", } }, + { "cs", { 24, "czech", } }, + { "ro", { 25, "romanian", } }, + { "da", { 26, "danish", } }, + { "hu", { 27, "hungarian", } }, + { "ta", { 28, "tamil", } }, + { "no", { 29, "norwegian", } }, + { "th", { 30, "thai", } }, + { "ur", { 31, "urdu", } }, + { "hr", { 32, "croatian", } }, + { "bg", { 33, "bulgarian", } }, + { "lt", { 34, "lithuanian", } }, + { "la", { 35, "latin", } }, + { "mi", { 36, "maori", } }, + { "ml", { 37, "malayalam", } }, + { "cy", { 38, "welsh", } }, + { "sk", { 39, "slovak", } }, + { "te", { 40, "telugu", } }, + { "fa", { 41, "persian", } }, + { "lv", { 42, "latvian", } }, + { "bn", { 43, "bengali", } }, + { "sr", { 44, "serbian", } }, + { "az", { 45, "azerbaijani", } }, + { "sl", { 46, "slovenian", } }, + { "kn", { 47, "kannada", } }, + { "et", { 48, "estonian", } }, + { "mk", { 49, "macedonian", } }, + { "br", { 50, "breton", } }, + { "eu", { 51, "basque", } }, + { "is", { 52, "icelandic", } }, + { "hy", { 53, "armenian", } }, + { "ne", { 54, "nepali", } }, + { "mn", { 55, "mongolian", } }, + { "bs", { 56, "bosnian", } }, + { "kk", { 57, "kazakh", } }, + { "sq", { 58, "albanian", } }, + { "sw", { 59, "swahili", } }, + { "gl", { 60, "galician", } }, + { "mr", { 61, "marathi", } }, + { "pa", { 62, "punjabi", } }, + { "si", { 63, "sinhala", } }, + { "km", { 64, "khmer", } }, + { "sn", { 65, "shona", } }, + { "yo", { 66, "yoruba", } }, + { "so", { 67, "somali", } }, + { "af", { 68, "afrikaans", } }, + { "oc", { 69, "occitan", } }, + { "ka", { 70, "georgian", } }, + { "be", { 71, "belarusian", } }, + { "tg", { 72, "tajik", } }, + { "sd", { 73, "sindhi", } }, + { "gu", { 74, "gujarati", } }, + { "am", { 75, "amharic", } }, + { "yi", { 76, "yiddish", } }, + { "lo", { 77, "lao", } }, + { "uz", { 78, "uzbek", } }, + { "fo", { 79, "faroese", } }, + { "ht", { 80, "haitian creole", } }, + { "ps", { 81, "pashto", } }, + { "tk", { 82, "turkmen", } }, + { "nn", { 83, "nynorsk", } }, + { "mt", { 84, "maltese", } }, + { "sa", { 85, "sanskrit", } }, + { "lb", { 86, "luxembourgish", } }, + { "my", { 87, "myanmar", } }, + { "bo", { 88, "tibetan", } }, + { "tl", { 89, "tagalog", } }, + { "mg", { 90, "malagasy", } }, + { "as", { 91, "assamese", } }, + { "tt", { 92, "tatar", } }, + { "haw", { 93, "hawaiian", } }, + { "ln", { 94, "lingala", } }, + { "ha", { 95, "hausa", } }, + { "ba", { 96, "bashkir", } }, + { "jw", { 97, "javanese", } }, + { "su", { 98, "sundanese", } }, + { "yue", { 99, "cantonese", } }, +};
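A note on how this table is consumed: callers never touch g_lang directly; they go through the exported whisper.h helpers, which also accept full language names. Sketch:

    const int id = whisper_lang_id("de");      // 2; whisper_lang_id("german") resolves the same
    const char * code = whisper_lang_str(id);  // "de"
    const int bad = whisper_lang_id("xx");     // -1 for unknown languages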
"kazakh", } }, + { "sq", { 58, "albanian", } }, + { "sw", { 59, "swahili", } }, + { "gl", { 60, "galician", } }, + { "mr", { 61, "marathi", } }, + { "pa", { 62, "punjabi", } }, + { "si", { 63, "sinhala", } }, + { "km", { 64, "khmer", } }, + { "sn", { 65, "shona", } }, + { "yo", { 66, "yoruba", } }, + { "so", { 67, "somali", } }, + { "af", { 68, "afrikaans", } }, + { "oc", { 69, "occitan", } }, + { "ka", { 70, "georgian", } }, + { "be", { 71, "belarusian", } }, + { "tg", { 72, "tajik", } }, + { "sd", { 73, "sindhi", } }, + { "gu", { 74, "gujarati", } }, + { "am", { 75, "amharic", } }, + { "yi", { 76, "yiddish", } }, + { "lo", { 77, "lao", } }, + { "uz", { 78, "uzbek", } }, + { "fo", { 79, "faroese", } }, + { "ht", { 80, "haitian creole", } }, + { "ps", { 81, "pashto", } }, + { "tk", { 82, "turkmen", } }, + { "nn", { 83, "nynorsk", } }, + { "mt", { 84, "maltese", } }, + { "sa", { 85, "sanskrit", } }, + { "lb", { 86, "luxembourgish", } }, + { "my", { 87, "myanmar", } }, + { "bo", { 88, "tibetan", } }, + { "tl", { 89, "tagalog", } }, + { "mg", { 90, "malagasy", } }, + { "as", { 91, "assamese", } }, + { "tt", { 92, "tatar", } }, + { "haw", { 93, "hawaiian", } }, + { "ln", { 94, "lingala", } }, + { "ha", { 95, "hausa", } }, + { "ba", { 96, "bashkir", } }, + { "jw", { 97, "javanese", } }, + { "su", { 98, "sundanese", } }, + { "yue", { 99, "cantonese", } }, +}; + +// [EXPERIMENTAL] Token-level timestamps with DTW +static const whisper_ahead g_aheads_tiny_en[] = { {1, 0}, {2, 0}, {2, 5}, {3, 0}, {3, 1}, {3, 2}, {3, 3}, {3, 4} }; +static const whisper_ahead g_aheads_tiny[] = { {2, 2}, {3, 0}, {3, 2}, {3, 3}, {3, 4}, {3, 5} }; +static const whisper_ahead g_aheads_base_en[] = { {3, 3}, {4, 7}, {5, 1}, {5, 5}, {5, 7} }; +static const whisper_ahead g_aheads_base[] = { {3, 1}, {4, 2}, {4, 3}, {4, 7}, {5, 1}, {5, 2}, {5, 4}, {5, 6} }; +static const whisper_ahead g_aheads_small_en[] = { {6, 6}, {7, 0}, {7, 3}, {7, 8}, {8, 2}, {8, 5}, {8, 7}, {9, 0}, {9, 4}, {9, 8}, {9, 10}, {10, 0}, {10, 1}, {10, 2}, {10, 3}, {10, 6}, {10, 11}, {11, 2}, {11, 4} }; +static const whisper_ahead g_aheads_small[] = { {5, 3}, {5, 9}, {8, 0}, {8, 4}, {8, 7}, {8, 8}, {9, 0}, {9, 7}, {9, 9}, {10, 5} }; +static const whisper_ahead g_aheads_medium_en[] = { {11, 4}, {14, 1}, {14, 12}, {14, 14}, {15, 4}, {16, 0}, {16, 4}, {16, 9}, {17, 12}, {17, 14}, {18, 7}, {18, 10}, {18, 15}, {20, 0}, {20, 3}, {20, 9}, {20, 14}, {21, 12} }; +static const whisper_ahead g_aheads_medium[] = { {13, 15}, {15, 4}, {15, 15}, {16, 1}, {20, 0}, {23, 4} }; +static const whisper_ahead g_aheads_large_v1[] = { {9, 19}, {11, 2}, {11, 4}, {11, 17}, {22, 7}, {22, 11}, {22, 17}, {23, 2}, {23, 15} }; +static const whisper_ahead g_aheads_large_v2[] = { {10, 12}, {13, 17}, {16, 11}, {16, 12}, {16, 13}, {17, 15}, {17, 16}, {18, 4}, {18, 11}, {18, 19}, {19, 11}, {21, 2}, {21, 3}, {22, 3}, {22, 9}, {22, 12}, {23, 5}, {23, 7}, {23, 13}, {25, 5}, {26, 1}, {26, 12}, {27, 15} }; +static const whisper_ahead g_aheads_large_v3[] = { {7, 0}, {10, 17}, {12, 18}, {13, 12}, {16, 1}, {17, 14}, {19, 11}, {21, 4}, {24, 1}, {25, 6} }; + +static const std::map g_aheads { + { WHISPER_AHEADS_TINY_EN, { 8, g_aheads_tiny_en } }, + { WHISPER_AHEADS_TINY, { 6, g_aheads_tiny } }, + { WHISPER_AHEADS_BASE_EN, { 5, g_aheads_base_en } }, + { WHISPER_AHEADS_BASE, { 8, g_aheads_base } }, + { WHISPER_AHEADS_SMALL_EN, { 19, g_aheads_small_en } }, + { WHISPER_AHEADS_SMALL, { 10, g_aheads_small } }, + { WHISPER_AHEADS_MEDIUM_EN, { 18, g_aheads_medium_en } }, + { WHISPER_AHEADS_MEDIUM, { 6, 
+ +static std::vector<uint32_t> get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int32_t n_text_layer, int32_t n_head); + +struct whisper_mel { + int n_len; + int n_len_org; + int n_mel; + + std::vector<float> data; +}; + +struct whisper_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector<float> data; +}; + +struct whisper_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 51864; + + std::map<token, id> token_to_id; + std::map<id, token> id_to_token; + + // reference: https://github.com/openai/whisper/blob/248b6cb124225dd263bb9bd32d060b6517e067f8/whisper/tokenizer.py#L334-L349 + id token_eot = 50256; + id token_sot = 50257; + // task tokens (used only for multilingual models) + id token_translate = 50357; + id token_transcribe = 50358; + // other special tokens + id token_solm = 50359; // [TDRZ] used by tinydiarize models to indicate speaker turn + id token_prev = 50360; + id token_nosp = 50361; + id token_not = 50362; // no timestamps + id token_beg = 50363; // begin timestamps + + bool is_multilingual() const { + return n_vocab >= 51865; + } + + int num_languages() const { + return n_vocab - 51765 - (is_multilingual() ? 1 : 0); + } +};
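The magic constants decode as follows (a worked check against the language table above; english-only vocabularies stay at the 51864 defaults):

    // multilingual (v1/v2): n_vocab == 51865 -> num_languages() == 51865 - 51765 - 1 == 99
    // large-v3:             n_vocab == 51866 -> num_languages() == 100 (the "yue" entry above)
    // every special-token id then shifts up, which is what the loader's
    // `dt = vocab.num_languages() - 98` adjustment further down accounts for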
+ +struct whisper_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector<whisper_token_data> tokens; + + bool speaker_turn_next; +}; + +struct whisper_batch { + int32_t n_tokens; + + whisper_token * token; + whisper_pos * pos; + int32_t * n_seq_id; // always 1, here for consistency with llama.cpp + whisper_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +static struct whisper_batch whisper_batch_init(int32_t n_tokens, int32_t n_seq_max) { + whisper_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (whisper_token * ) malloc(sizeof(whisper_token) * (n_tokens)); + batch.pos = (whisper_pos *) malloc(sizeof(whisper_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (whisper_seq_id **) malloc(sizeof(whisper_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (whisper_seq_id *) malloc(sizeof(whisper_seq_id) * n_seq_max); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void whisper_batch_free(struct whisper_batch batch) { + if (batch.token) free(batch.token); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +}
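For orientation, the intended call pattern for these helpers (a sketch; the decode path later in this file is what actually drives it):

    whisper_batch batch = whisper_batch_init(WHISPER_MAX_DECODERS, /*n_seq_max*/ 1);
    whisper_token tok = 50257; // hypothetical token, e.g. the sot id from the vocab above
    whisper_batch_prep_legacy(batch, &tok, /*n_tokens*/ 1, /*n_past*/ 0, /*seq_id*/ 0);
    // ... run the decoder on `batch`; logits are requested only for the last token ...
    whisper_batch_free(batch);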
+ +// replace std::pair by using customized pair struct (reason: std::pair is very slow) +template<typename A, typename B> +struct whisper_pair { + A first; + B second; + + // Define a constructor that takes two arguments. + whisper_pair(const A& a, const B& b) : first(a), second(b) {} + // Define a constructor that takes no argument. + whisper_pair() : first(A()), second(B()) {} +}; + +// ggml_allocr wrapper for whisper usage +struct whisper_allocr { + ggml_gallocr_t alloc = nullptr; + + std::vector<uint8_t> meta; +}; + +static size_t whisper_allocr_size(struct whisper_allocr & allocr) { + return allocr.meta.size() + ggml_gallocr_get_buffer_size(allocr.alloc, 0); +} + +// measure the memory usage of a graph and prepare the allocr's internal data buffer +static bool whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) { + auto & alloc = allocr.alloc; + auto & meta = allocr.meta; + + alloc = ggml_gallocr_new(ggml_backend_get_default_buffer_type(backend)); + + meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead()); + + // since there are dependencies between the different graphs, + // we need to allocate them instead of only reserving to get the correct compute buffer size + if (!ggml_gallocr_alloc_graph(alloc, get_graph())) { + // failed to allocate the compute buffer + WHISPER_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + return false; + } + return true; +} + +// medium +// hparams: { +// 'n_mels': 80, +// 'n_vocab': 51864, +// 'n_audio_ctx': 1500, +// 'n_audio_state': 1024, +// 'n_audio_head': 16, +// 'n_audio_layer': 24, +// 'n_text_ctx': 448, +// 'n_text_state': 1024, +// 'n_text_head': 16, +// 'n_text_layer': 24 +// } +// +// default hparams (Whisper tiny) +struct whisper_hparams { + int32_t n_vocab = 51864; + int32_t n_audio_ctx = 1500; + int32_t n_audio_state = 384; + int32_t n_audio_head = 6; + int32_t n_audio_layer = 4; + int32_t n_text_ctx = 448; + int32_t n_text_state = 384; + int32_t n_text_head = 6; + int32_t n_text_layer = 4; + int32_t n_mels = 80; + int32_t ftype = 1; + float eps = 1e-5f; +}; + +// audio encoding layer +struct whisper_layer_encoder { + // encoder.blocks.*.attn_ln + struct ggml_tensor * attn_ln_0_w; + struct ggml_tensor * attn_ln_0_b; + + // encoder.blocks.*.attn.out + struct ggml_tensor * attn_ln_1_w; + struct ggml_tensor * attn_ln_1_b; + + // encoder.blocks.*.attn.query + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_q_b; + + // encoder.blocks.*.attn.key + struct ggml_tensor * attn_k_w; + + // encoder.blocks.*.attn.value + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_v_b; + + // encoder.blocks.*.mlp_ln + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + + // encoder.blocks.*.mlp.0 + struct ggml_tensor * mlp_0_w; + struct ggml_tensor * mlp_0_b; + + // encoder.blocks.*.mlp.2 + struct ggml_tensor * mlp_1_w; + struct ggml_tensor * mlp_1_b; +}; + +// token decoding layer +struct whisper_layer_decoder { + // decoder.blocks.*.attn_ln + struct ggml_tensor * attn_ln_0_w; + struct ggml_tensor * attn_ln_0_b; + + // decoder.blocks.*.attn.out + struct ggml_tensor * attn_ln_1_w; + struct ggml_tensor * attn_ln_1_b; + + // decoder.blocks.*.attn.query + struct ggml_tensor * attn_q_w; + struct ggml_tensor * attn_q_b; + + // decoder.blocks.*.attn.key + struct ggml_tensor * attn_k_w; + + // decoder.blocks.*.attn.value + struct ggml_tensor * attn_v_w; + struct ggml_tensor * attn_v_b; + + // decoder.blocks.*.cross_attn_ln + struct ggml_tensor * cross_attn_ln_0_w; + struct ggml_tensor * cross_attn_ln_0_b; + + // decoder.blocks.*.cross_attn.out + struct ggml_tensor * cross_attn_ln_1_w; + struct ggml_tensor * cross_attn_ln_1_b; + + // decoder.blocks.*.cross_attn.query + struct ggml_tensor * cross_attn_q_w; + struct ggml_tensor * cross_attn_q_b; + + // decoder.blocks.*.cross_attn.key + struct ggml_tensor * cross_attn_k_w; + + // decoder.blocks.*.cross_attn.value + struct ggml_tensor * cross_attn_v_w; + struct ggml_tensor * cross_attn_v_b; + + // decoder.blocks.*.mlp_ln + struct ggml_tensor * mlp_ln_w; + struct ggml_tensor * mlp_ln_b; + + // decoder.blocks.*.mlp.0 + struct ggml_tensor * mlp_0_w; + struct ggml_tensor * mlp_0_b; + + // decoder.blocks.*.mlp.2 + struct ggml_tensor * mlp_1_w; + struct ggml_tensor * mlp_1_b; +}; + +struct whisper_kv_cell { + whisper_pos pos = -1; + + std::set<whisper_seq_id> seq_id; + + bool has_seq_id(const whisper_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } +}; + +struct whisper_kv_cache { + uint32_t head = 0; + uint32_t size = 0; + + // computed before each graph build + uint32_t n = 0; + + std::vector<whisper_kv_cell> cells; + + struct ggml_tensor * k; + struct ggml_tensor * v; + + struct ggml_context * ctx = nullptr; + + ggml_backend_buffer_t buffer = nullptr; +};
+ +struct whisper_model { + e_model type = MODEL_UNKNOWN; + + whisper_hparams hparams; + whisper_filters filters; + + // encoder.positional_embedding + struct ggml_tensor * e_pe; + + // encoder.conv1 + struct ggml_tensor * e_conv_1_w; + struct ggml_tensor * e_conv_1_b; + + // encoder.conv2 + struct ggml_tensor * e_conv_2_w; + struct ggml_tensor * e_conv_2_b; + + // encoder.ln_post + struct ggml_tensor * e_ln_w; + struct ggml_tensor * e_ln_b; + + // decoder.positional_embedding + struct ggml_tensor * d_pe; + + // decoder.token_embedding + struct ggml_tensor * d_te; + + // decoder.ln + struct ggml_tensor * d_ln_w; + struct ggml_tensor * d_ln_b; + + std::vector<whisper_layer_encoder> layers_encoder; + std::vector<whisper_layer_decoder> layers_decoder; + + // ggml context that contains all the meta information about the model tensors + struct ggml_context * ctx = nullptr; + + // the model backend data is read-only and can be shared between processors + ggml_backend_buffer_t buffer = nullptr; + + // tensors + int n_loaded; + std::map<std::string, struct ggml_tensor *> tensors; +}; + +struct whisper_partial_utf8 { + uint32_t value; // bit value so far (unshifted) + int n_remain; // num bytes remaining; -1 indicates invalid sequence +}; + +struct whisper_grammar { + /*const*/ std::vector<std::vector<whisper_grammar_element>> rules; + std::vector<std::vector<const whisper_grammar_element *>> stacks; + + // buffer for partially generated UTF-8 sequence from accepted tokens + whisper_partial_utf8 partial_utf8; +}; + +struct whisper_grammar_candidate { + whisper_token id; + const uint32_t * code_points; + whisper_partial_utf8 partial_utf8; +}; + +struct whisper_sequence { + std::vector<whisper_token_data> tokens; + + // the accumulated transcription in the current iteration (used to truncate the tokens array) + int result_len; + + double sum_logprobs_all; // the sum of the log probabilities of the tokens + double sum_logprobs; // the sum of the log probabilities of the tokens (first result_len tokens) + double avg_logprobs; // the average log probability of the tokens + double entropy; // the entropy of the tokens + double score; // likelihood rank score +}; + +// TAGS: WHISPER_DECODER_INIT +struct whisper_decoder { + // the currently generated sequence of tokens + whisper_sequence sequence; + + // grammar parse state of generated sequence of tokens + whisper_grammar grammar; + + int i_batch; // the index of the token in the current batch + int seek_delta; // the window shift found so far based on the decoded timestamp tokens + + bool failed; // has the current segment failed to decode? + bool completed; // has the decoder completed the current segment? + bool has_ts; // have we already sampled a non-beg timestamp token for the current segment? + + // new token probs, logits and logprobs after the last whisper_decode (1-dimensional array: [n_vocab]) + std::vector<float> probs; + std::vector<float> logits; + std::vector<float> logprobs; + + // work container used to avoid memory allocations + std::vector<whisper_pair<double, whisper_token>> logits_id; + + mutable std::mt19937 rng; // used for sampling at t > 0.0 +}; + +// [EXPERIMENTAL] Token-level timestamps with DTW +struct whisper_aheads_masks { + std::vector<struct ggml_tensor *> m; // One mask per text layer. + struct ggml_context * ctx = nullptr; + ggml_backend_buffer_t buffer = nullptr; +}; + +struct whisper_state { + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_batchd_us = 0; + int64_t t_prompt_us = 0; + int64_t t_mel_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_batchd = 0; // number of decoder calls with n_tokens < 16 (batch decoding) + int32_t n_prompt = 0; // number of decoder calls with n_tokens > 1 (prompt encoding) + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + // unified self-attention KV cache for all decoders + whisper_kv_cache kv_self; + + // cross-attention KV cache for the decoders + // shared between all decoders + whisper_kv_cache kv_cross; + + // padded buffer for flash-attention + whisper_kv_cache kv_pad; + + whisper_mel mel; + + whisper_batch batch; + + whisper_decoder decoders[WHISPER_MAX_DECODERS]; + + ggml_backend_t backend = nullptr; + + // ggml-alloc: + // - stores meta info about the intermediate tensors into the `meta` buffers + // - stores the actual tensor data into the `data` buffers + whisper_allocr alloc_conv; + whisper_allocr alloc_encode; + whisper_allocr alloc_cross; + whisper_allocr alloc_decode; + + // result of the encoder + struct ggml_tensor * embd_conv = nullptr; + struct ggml_tensor * embd_enc = nullptr; + + // helpers for GPU offloading + std::vector<float> inp_mel; + std::vector<float> inp_mask; + + // decode output (2-dimensional array: [n_tokens][n_vocab]) + std::vector<float> logits; + + std::vector<whisper_segment> result_all; + std::vector<whisper_token> prompt_past; + + int lang_id = 0; // english by default + + std::string path_model; // populated by whisper_init_from_file_with_params() + +#ifdef WHISPER_USE_COREML + whisper_coreml_context * ctx_coreml = nullptr; +#endif + +#ifdef WHISPER_USE_OPENVINO + whisper_openvino_context * ctx_openvino = nullptr; +#endif + + // [EXPERIMENTAL] token-level timestamps data + int64_t t_beg = 0; + int64_t t_last = 0; + + whisper_token tid_last; + + std::vector<float> energy; // PCM signal energy + + // [EXPERIMENTAL] Token-level timestamps with DTW + whisper_aheads_masks aheads_masks; + ggml_tensor * aheads_cross_QKs = nullptr; + std::vector<float> aheads_cross_QKs_data; + + // [EXPERIMENTAL] speed-up techniques + int32_t exp_n_audio_ctx = 0; // 0 - use default +}; + +struct whisper_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; // weight type (FP32 / FP16 / QX) + ggml_type itype = ggml_type::GGML_TYPE_F16; // intermediate type (FP32 or FP16) + + whisper_context_params params; + + whisper_model model; + whisper_vocab vocab; + + whisper_state * state = nullptr; + + ggml_backend_t backend = nullptr; + + std::string path_model; // populated by whisper_init_from_file_with_params() +};
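The context/state split above is what makes the "read-only and can be shared between processors" comment on whisper_model concrete: one whisper_context can serve several whisper_states. A sketch using the public whisper.h calls:

    // share one set of weights across independent transcriptions
    whisper_state * st = whisper_init_state(wctx);
    whisper_full_with_state(wctx, st, wparams, pcmf32.data(), pcmf32.size());
    whisper_free_state(st); // the context (and model weights) outlive the state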
+ +struct whisper_global { + // We save the log callback globally + ggml_log_callback log_callback = whisper_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static whisper_global g_state; + +template<typename T> +static void read_safe(whisper_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool kv_cache_init( + struct whisper_kv_cache & cache, + ggml_backend_t backend, + ggml_type wtype, + int64_t n_text_state, + int64_t n_text_layer, + int n_ctx) { + const int64_t n_mem = n_text_layer*n_ctx; + const int64_t n_elements = n_text_state*n_mem; + + struct ggml_init_params params = { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + cache.head = 0; + cache.size = n_ctx; + + cache.cells.clear(); + cache.cells.resize(n_ctx); + + cache.ctx = ggml_init(params); + + if (!cache.ctx) { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache context\n", __func__); + return false; + } + + cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements); + + cache.buffer = ggml_backend_alloc_ctx_tensors(cache.ctx, backend); + if (!cache.buffer) { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the kv cache\n", __func__); + return false; + } + + ggml_backend_buffer_clear(cache.buffer, 0); + + return true; +}
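A quick sizing check of what this allocates, using the medium-model numbers from the hparams comment above and assuming an F16 cache type:

    // medium: n_text_state = 1024, n_text_layer = 24, n_ctx = n_text_ctx = 448
    //   n_mem      = 24 * 448     = 10752
    //   n_elements = 1024 * 10752 = 11010048 per tensor (k and v each)
    //   with F16: 2 tensors * 11010048 * 2 bytes ~= 44 MB for one self-attention cache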
+ +static void kv_cache_free(struct whisper_kv_cache & cache) { + ggml_free(cache.ctx); + ggml_backend_buffer_free(cache.buffer); + cache.ctx = nullptr; +} + +static bool whisper_kv_cache_find_slot( + struct whisper_kv_cache & cache, + const struct whisper_batch & batch) { + const uint32_t n_ctx = cache.size; + const uint32_t n_tokens = batch.n_tokens; + + if (n_tokens > n_ctx) { + WHISPER_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); + return false; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.head + n_tokens > n_ctx) { + n_tested += n_ctx - cache.head; + cache.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.cells[cache.head + i].pos >= 0) { + found = false; + cache.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= n_ctx) { + //WHISPER_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t i = 0; i < n_tokens; i++) { + cache.cells[cache.head + i].pos = batch.pos[i]; + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + } + } + + return true; +} + +// find how many cells are currently in use +static int32_t whisper_kv_cache_cell_max(const struct whisper_kv_cache & cache) { + for (uint32_t i = cache.size - 1; i > 0; --i) { + if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) { + return i + 1; + } + } + + return 1; +} + +static void whisper_kv_cache_clear(struct whisper_kv_cache & cache) { + for (int32_t i = 0; i < (int32_t) cache.size; ++i) { + cache.cells[i].pos = -1; + cache.cells[i].seq_id.clear(); + } + cache.head = 0; +} + +static void whisper_kv_cache_seq_rm( + struct whisper_kv_cache & cache, + whisper_seq_id seq_id, + whisper_pos p0, + whisper_pos p1) { + uint32_t new_head = cache.size; + + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if (seq_id < 0) { + cache.cells[i].seq_id.clear(); + } else if (cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cache.cells[i].seq_id.empty()) { + cache.cells[i].pos = -1; + if (new_head == cache.size) new_head = i; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) cache.head = new_head; +} + +static void whisper_kv_cache_seq_cp( + struct whisper_kv_cache & cache, + whisper_seq_id seq_id_src, + whisper_seq_id seq_id_dst, + whisper_pos p0, + whisper_pos p1) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits<whisper_pos>::max(); + + cache.head = 0; + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.cells[i].seq_id.insert(seq_id_dst); + } + } +} + +static uint32_t whisper_kv_cache_get_padding(const struct whisper_context & wctx) { + if (!wctx.params.flash_attn) { + return 1u; + } + +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(wctx.backend)) { + return 32u; + } +#endif + +#ifdef GGML_USE_CUDA + if (ggml_backend_is_cuda(wctx.backend)) { + return 256u; + } +#endif + + return 1u; +}
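This padding feeds the GGML_PAD(n_ctx, 256) rounding used by the graph builders further down; concretely:

    // GGML_PAD(x, n) rounds x up to a multiple of n:
    //   GGML_PAD( 448, 256) == 512   (decoder text context)
    //   GGML_PAD(1500, 256) == 1536  (full encoder audio context)
    // without flash attention the padding is 1, so the sizes are unchanged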
+ +// [EXPERIMENTAL] Token-level timestamps with DTW +static bool aheads_masks_init( + const whisper_context_params & cparams, + const whisper_hparams & hparams, + struct whisper_aheads_masks & aheads_masks, + ggml_backend_t backend) { + + const int32_t n_text_layer = hparams.n_text_layer; + const int32_t n_head = hparams.n_text_head; + + // Sanity checks + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { + WHISPER_LOG_ERROR("%s: dtw_aheads_preset should be != DTW_AHEADS_NONE\n", __func__); + return false; + } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) { + if (cparams.dtw_n_top > n_text_layer || cparams.dtw_n_top <= 0) { + WHISPER_LOG_ERROR("%s: dtw_n_top must be between %d and %d for this model.", __func__, 1, n_text_layer); + return false; + } + } else { + const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset); + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM) { + if (aheads.n_heads == 0) { + WHISPER_LOG_ERROR("%s: dtw_aheads.n_heads should be > 0", __func__); + return false; + } + if (aheads.heads == NULL) { + WHISPER_LOG_ERROR("%s: dtw_aheads.heads unset", __func__); + return false; + } + } + for (size_t i = 0; i < aheads.n_heads; ++i) { + if (aheads.heads[i].n_text_layer >= n_text_layer) { + WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer %d, but model only has %d text layers", __func__, aheads.heads[i].n_text_layer + 1, n_text_layer); + return false; + } + if (aheads.heads[i].n_text_layer < 0) { + WHISPER_LOG_ERROR("%s: tried to set alignment head on text layer < 0", __func__); + return false; + } + if (aheads.heads[i].n_head >= n_head) { + WHISPER_LOG_ERROR("%s: tried to set alignment head on head %d, but model only has %d heads", __func__, aheads.heads[i].n_head + 1, n_head); + return false; + } + if (aheads.heads[i].n_head < 0) { + WHISPER_LOG_ERROR("%s: tried to set alignment head on head < 0", __func__); + return false; + } + } + } + + struct ggml_init_params params = { + /*.mem_size =*/ (size_t) static_cast<size_t>(n_text_layer)*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + aheads_masks.ctx = ggml_init(params); + + if (!aheads_masks.ctx) { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the aheads_masks context\n", __func__); + return false; + } + + for (int64_t il = 0; il < n_text_layer; ++il) { + auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head); + if (!aheads.empty()) { + aheads_masks.m.push_back(ggml_new_tensor_2d(aheads_masks.ctx, GGML_TYPE_F32, n_head, aheads.size())); + } else { + aheads_masks.m.push_back(nullptr); + } + } + + aheads_masks.buffer = ggml_backend_alloc_ctx_tensors(aheads_masks.ctx, backend); + if (!aheads_masks.buffer) { + WHISPER_LOG_ERROR("%s: failed to allocate memory for aheads_masks\n", __func__); + return false; + } + + // Set data on mask tensors + // Since this must be backend agnostic, we write our desired values on mask_data, + // and send it to backend with ggml_backend_tensor_set. + // Each mask is N_HEADS*N_ALIGNMENT_HEADS, one per text layer containing alignment + // heads. Each row of the mask "marks" one alignment head. E.g. if some text layer + // has a total of 10 heads and of those, heads 0,5,6 are alignment heads, the mask + // should read: + // 1 0 0 0 0 0 0 0 0 0 + // 0 0 0 0 0 1 0 0 0 0 + // 0 0 0 0 0 0 1 0 0 0 + std::vector<float> mask_data; + for (int64_t il = 0; il < n_text_layer; ++il) { + if (aheads_masks.m[il] != nullptr) { + auto aheads = get_alignment_heads_by_layer(cparams, il, n_text_layer, n_head); + + size_t data_size = aheads_masks.m[il]->ne[0] * aheads_masks.m[il]->ne[1]; + size_t data_size_bytes = data_size * sizeof(float); + mask_data.resize(data_size); + + std::fill(mask_data.begin(), mask_data.end(), 0); + for (size_t ih = 0; ih < aheads.size(); ++ih) { + size_t pos = (aheads[ih] + (ih * aheads_masks.m[il]->ne[0])); + mask_data[pos] = 1.0f; + } + + ggml_backend_tensor_set(aheads_masks.m[il], mask_data.data(), 0, data_size_bytes); + } + } + + if (aheads_masks.m.empty()) { + WHISPER_LOG_ERROR("%s: \n", __func__); + return false; + } + + return true; +}
+ +static void aheads_masks_free(struct whisper_aheads_masks & aheads_masks) { + ggml_free(aheads_masks.ctx); + ggml_backend_buffer_free(aheads_masks.buffer); + aheads_masks.ctx = nullptr; +} + +static size_t aheads_masks_nbytes(struct whisper_aheads_masks & aheads_masks) { + size_t size = 0; + for (size_t i = 0; i < aheads_masks.m.size(); ++i) { + if (aheads_masks.m[i] != nullptr) + size += ggml_nbytes(aheads_masks.m[i]); + } + return size; +} + +static ggml_backend_t whisper_backend_init(const whisper_context_params & params) { + ggml_backend_t backend_gpu = NULL; + + // initialize the backends +#ifdef GGML_USE_CUDA + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); + backend_gpu = ggml_backend_cuda_init(params.gpu_device); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); + ggml_backend_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data); + backend_gpu = ggml_backend_metal_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__); + } else if (!ggml_backend_metal_supports_family(backend_gpu, 7)) { + WHISPER_LOG_ERROR("%s: Metal GPU does not support family 7 - falling back to CPU\n", __func__); + ggml_backend_free(backend_gpu); + backend_gpu = NULL; + } + } +#endif + +#ifdef GGML_USE_SYCL + if (params.use_gpu) { + WHISPER_LOG_INFO("%s: using SYCL backend\n", __func__); + backend_gpu = ggml_backend_sycl_init(params.gpu_device); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_sycl_init() failed\n", __func__); + } + } +#endif + + if (backend_gpu) { + return backend_gpu; + } + return ggml_backend_cpu_init(); +}
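Backend choice is driven entirely by the two context params consumed above; nothing else changes at call sites. Sketch:

    whisper_context_params cparams = whisper_context_default_params();
    cparams.use_gpu    = false; // skip every GPU branch above and take ggml_backend_cpu_init()
    cparams.gpu_device = 0;     // only consulted by the CUDA/SYCL branches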
+ +// load the model from a ggml file +// +// file format: +// +// - hparams +// - pre-computed mel filters +// - vocab +// - weights +// +// see the convert-pt-to-ggml.py script for details +// +static bool whisper_model_load(struct whisper_model_loader * loader, whisper_context & wctx) { + WHISPER_LOG_INFO("%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != GGML_FILE_MAGIC) { + WHISPER_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + { + auto & hparams = model.hparams; + + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_text_ctx); + read_safe(loader, hparams.n_text_state); + read_safe(loader, hparams.n_text_head); + read_safe(loader, hparams.n_text_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.ftype); + + assert(hparams.n_text_state == hparams.n_audio_state); + + std::string mver = ""; + + if (hparams.n_audio_layer == 4) { + model.type = e_model::MODEL_TINY; + } + + if (hparams.n_audio_layer == 6) { + model.type = e_model::MODEL_BASE; + } + + if (hparams.n_audio_layer == 12) { + model.type = e_model::MODEL_SMALL; + } + + if (hparams.n_audio_layer == 24) { + model.type = e_model::MODEL_MEDIUM; + } + + if (hparams.n_audio_layer == 32) { + model.type = e_model::MODEL_LARGE; + + if (hparams.n_vocab == 51866) { + mver = " v3"; + } + } + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) (model.hparams.ftype)); + if (wctx.wtype == GGML_TYPE_COUNT) { + WHISPER_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, model.hparams.ftype); + return false; + } + + WHISPER_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + WHISPER_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + WHISPER_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + WHISPER_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + WHISPER_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + WHISPER_LOG_INFO("%s: n_text_ctx = %d\n", __func__, hparams.n_text_ctx); + WHISPER_LOG_INFO("%s: n_text_state = %d\n", __func__, hparams.n_text_state); + WHISPER_LOG_INFO("%s: n_text_head = %d\n", __func__, hparams.n_text_head); + WHISPER_LOG_INFO("%s: n_text_layer = %d\n", __func__, hparams.n_text_layer); + WHISPER_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + WHISPER_LOG_INFO("%s: ftype = %d\n", __func__, model.hparams.ftype); + WHISPER_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + WHISPER_LOG_INFO("%s: type = %d (%s%s)\n", __func__, model.type, g_model_name.at(model.type).c_str(), mver.c_str()); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fft); + + filters.data.resize(filters.n_mel * filters.n_fft); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + //if (n_vocab != model.hparams.n_vocab) { + // WHISPER_LOG_ERROR("%s: invalid model file '%s' (bad vocab size %d != %d)\n", + // __func__, fname.c_str(), n_vocab, model.hparams.n_vocab); + // return false; + //} + + std::string word; + std::vector<char> tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + // seems like we have an empty-string token in multi-language models (i = 50256) + //WHISPER_LOG_WARN("%s: warning: empty-string token in vocab,
i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + + //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); + } + + vocab.n_vocab = model.hparams.n_vocab; + if (vocab.is_multilingual()) { + vocab.token_eot++; + vocab.token_sot++; + + // account for variable number of language tokens + const int dt = vocab.num_languages() - 98; + + vocab.token_translate += dt; + vocab.token_transcribe += dt; + vocab.token_solm += dt; + vocab.token_prev += dt; + vocab.token_nosp += dt; + vocab.token_not += dt; + vocab.token_beg += dt; + } + + if (n_vocab < model.hparams.n_vocab) { + WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab); + for (int i = n_vocab; i < model.hparams.n_vocab; i++) { + if (i > vocab.token_beg) { + word = "[_TT_" + std::to_string(i - vocab.token_beg) + "]"; + } else if (i == vocab.token_eot) { + word = "[_EOT_]"; + } else if (i == vocab.token_sot) { + word = "[_SOT_]"; + } else if (i == vocab.token_translate) { + word = "[_TRANSLATE_]"; + } else if (i == vocab.token_transcribe) { + word = "[_TRANSCRIBE_]"; + } else if (i == vocab.token_solm) { + word = "[_SOLM_]"; + } else if (i == vocab.token_prev) { + word = "[_PREV_]"; + } else if (i == vocab.token_nosp) { + word = "[_NOSP_]"; + } else if (i == vocab.token_not) { + word = "[_NOT_]"; + } else if (i == vocab.token_beg) { + word = "[_BEG_]"; + } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) { + word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]"; + } else { + word = "[_extra_token_" + std::to_string(i) + "]"; + } + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + } + } + + WHISPER_LOG_INFO("%s: n_langs = %d\n", __func__, vocab.num_languages()); + } + + const ggml_type wtype = wctx.wtype; + const ggml_type vtype = wctx.wtype == GGML_TYPE_F32 ? 
GGML_TYPE_F32 : GGML_TYPE_F16; // conv type + + // create the ggml context + { + const auto & hparams = model.hparams; + + const int n_audio_layer = hparams.n_audio_layer; + const int n_text_layer = hparams.n_text_layer; + + const size_t n_tensors = 10 /* input */ + 15 + 15*n_audio_layer + 24*n_text_layer; + + struct ggml_init_params params = { + /*.mem_size =*/ n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + model.ctx = ggml_init(params); + if (!model.ctx) { + WHISPER_LOG_ERROR("%s: ggml_init() failed\n", __func__); + return false; + } + } + + // prepare tensors for the weights + { + auto & ctx = model.ctx; + + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + + const int n_audio_ctx = hparams.n_audio_ctx; + const int n_audio_state = hparams.n_audio_state; + const int n_audio_layer = hparams.n_audio_layer; + + const int n_text_ctx = hparams.n_text_ctx; + const int n_text_state = hparams.n_text_state; + const int n_text_layer = hparams.n_text_layer; + + const int n_mels = hparams.n_mels; + + model.layers_encoder.resize(n_audio_layer); + model.layers_decoder.resize(n_text_layer); + + // encoder + { + model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); + + model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); + model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); + model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); + + model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.positional_embedding"] = model.e_pe; + + model.tensors["encoder.conv1.weight"] = model.e_conv_1_w; + model.tensors["encoder.conv1.bias"] = model.e_conv_1_b; + + model.tensors["encoder.conv2.weight"] = model.e_conv_2_w; + model.tensors["encoder.conv2.bias"] = model.e_conv_2_b; + + model.tensors["encoder.ln_post.weight"] = model.e_ln_w; + model.tensors["encoder.ln_post.bias"] = model.e_ln_b; + + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers_encoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_audio_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); + + // map by name + model.tensors["encoder.blocks." 
+ std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["encoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + } + } + + // decoder + { + model.d_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_text_state, n_text_ctx); + + model.d_te = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_vocab); + + model.d_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + model.d_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.positional_embedding"] = model.d_pe; + + model.tensors["decoder.token_embedding.weight"] = model.d_te; + + model.tensors["decoder.ln.weight"] = model.d_ln_w; + model.tensors["decoder.ln.bias"] = model.d_ln_b; + + for (int i = 0; i < n_text_layer; ++i) { + auto & layer = model.layers_decoder[i]; + + layer.mlp_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.mlp_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.mlp_0_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, 4*n_text_state); + layer.mlp_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4*n_text_state); + + layer.mlp_1_w = ggml_new_tensor_2d(ctx, wtype, 4*n_text_state, n_text_state); + layer.mlp_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_0_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + layer.cross_attn_ln_0_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_q_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_q_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + 
layer.cross_attn_k_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + + layer.cross_attn_v_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_v_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + layer.cross_attn_ln_1_w = ggml_new_tensor_2d(ctx, wtype, n_text_state, n_text_state); + layer.cross_attn_ln_1_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_text_state); + + // map by name + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.weight"] = layer.mlp_ln_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp_ln.bias"] = layer.mlp_ln_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.weight"] = layer.mlp_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.0.bias"] = layer.mlp_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.weight"] = layer.mlp_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".mlp.2.bias"] = layer.mlp_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.weight"] = layer.attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn_ln.bias"] = layer.attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.weight"] = layer.attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.query.bias"] = layer.attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.key.weight"] = layer.attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.weight"] = layer.attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.value.bias"] = layer.attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.weight"] = layer.attn_ln_1_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".attn.out.bias"] = layer.attn_ln_1_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.weight"] = layer.cross_attn_ln_0_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn_ln.bias"] = layer.cross_attn_ln_0_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.weight"] = layer.cross_attn_q_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.query.bias"] = layer.cross_attn_q_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.key.weight"] = layer.cross_attn_k_w; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.weight"] = layer.cross_attn_v_w; + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.value.bias"] = layer.cross_attn_v_b; + + model.tensors["decoder.blocks." + std::to_string(i) + ".cross_attn.out.weight"] = layer.cross_attn_ln_1_w; + model.tensors["decoder.blocks." 
+ std::to_string(i) + ".cross_attn.out.bias"] = layer.cross_attn_ln_1_b; + } + } + } + + wctx.backend = whisper_backend_init(wctx.params); + if (!wctx.backend) { + WHISPER_LOG_ERROR("%s: failed to initialize the backend\n", __func__); + return false; + } + + // allocate tensors in the backend buffers + model.buffer = ggml_backend_alloc_ctx_tensors(model.ctx, wctx.backend); + if (!model.buffer) { + WHISPER_LOG_ERROR("%s: failed to allocate memory for the model\n", __func__); + return false; + } + + size_t size_main = ggml_backend_buffer_get_size(model.buffer); + WHISPER_LOG_INFO("%s: %8s total size = %8.2f MB\n", __func__, ggml_backend_name(wctx.backend), size_main / 1e6); + + // load weights + { + size_t total_size = 0; + + model.n_loaded = 0; + + std::vector read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (model.tensors.find(name) == model.tensors.end()) { + WHISPER_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = model.tensors[name.data()]; + + if (ggml_nelements(tensor) != nelements) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); + return false; + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + //ggml_backend_t backend = wctx.backend; + + //printf("%s: [%5.5s] %s\n", __func__, ggml_backend_name(backend), name.c_str()); + + if (ggml_backend_buffer_is_host(model.buffer)) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + + loader->read(loader->context, read_buf.data(), read_buf.size()); + + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + //printf("%48s - [%5d, %5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ne[2], ggml_type_name((ggml_type) ttype), ggml_nbytes(tensor)/1e6); + total_size += ggml_nbytes(tensor); + model.n_loaded++; + } + + WHISPER_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + + if (model.n_loaded == 0) { + WHISPER_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for 
testing\n", __func__); + } else if (model.n_loaded != (int) model.tensors.size()) { + WHISPER_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), model.n_loaded); + return false; + } + } + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +static bool whisper_encode_external(const whisper_state & wstate) { + GGML_UNUSED(wstate); + +#ifndef WHISPER_USE_COREML + const bool use_coreml = false; +#else + const bool use_coreml = wstate.ctx_coreml != nullptr; +#endif + +#ifndef WHISPER_USE_OPENVINO + const bool use_openvino = false; +#else + const bool use_openvino = wstate.ctx_openvino != nullptr; +#endif + + return use_coreml || use_openvino; +} + +static struct ggml_cgraph * whisper_build_graph_conv( + whisper_context & wctx, + whisper_state & wstate) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; GGML_UNUSED(n_state); + + const int n_mels = hparams.n_mels; + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.alloc_conv.meta.size(), + /*.mem_buffer =*/ wstate.alloc_conv.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels); + ggml_set_name(mel, "mel"); + ggml_set_input(mel); + + struct ggml_tensor * cur = nullptr; + + if (!whisper_encode_external(wstate)) { + // convolution + gelu + { + cur = ggml_conv_1d_ph(ctx0, model.e_conv_1_w, mel, 1, 1); + cur = ggml_add(ctx0, cur, model.e_conv_1_b); + + cur = ggml_gelu(ctx0, cur); + + cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); + cur = ggml_add(ctx0, cur, model.e_conv_2_b); + + cur = ggml_gelu(ctx0, cur); + } + + ggml_set_name(cur, "embd_conv"); + wstate.embd_conv = cur; + } else { + ggml_build_forward_expand(gf, mel); + + cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx); + + ggml_set_name(cur, "embd_enc"); + wstate.embd_enc = cur; + } + + ggml_set_output(cur); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * whisper_build_graph_encoder( + whisper_context & wctx, + whisper_state & wstate) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + const int n_layer = hparams.n_audio_layer; + + const int n_state_head = n_state/n_head; + + auto & kv_pad = wstate.kv_pad; + + WHISPER_ASSERT(!!kv_pad.ctx); + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.alloc_encode.meta.size(), + /*.mem_buffer =*/ wstate.alloc_encode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); + + struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); + + const float KQscale = 1.0f/sqrtf(float(n_state_head)); + + // =================================================================== + // NOTE: experimenting with partial evaluation of the encoder (ignore) + //static int iter = -1; + //const int n_iter = 1500/n_ctx; + + //iter = (iter + 1) % n_iter; + + //if (iter == 0) { + // memset(model.memory_cross_k->data, 0, ggml_nbytes(model.memory_cross_k)); + // memset(model.memory_cross_v->data, 0, ggml_nbytes(model.memory_cross_v)); + //} + + static int iter = 0; + + const size_t e_pe_stride = model.e_pe->ne[0]*ggml_element_size(model.e_pe); + const size_t e_pe_offset = model.e_pe->ne[0]*ggml_element_size(model.e_pe)*n_ctx*iter; + + struct ggml_tensor * e_pe = ggml_view_2d(ctx0, model.e_pe, model.e_pe->ne[0], n_ctx, e_pe_stride, e_pe_offset); + cur = ggml_add(ctx0, e_pe, ggml_cont(ctx0, ggml_transpose(ctx0, cur))); + + // =================================================================== + + // original: + //cur = ggml_add(ctx0, model.e_pe, ggml_transpose(ctx0, cur)); + + struct ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_encoder[il]; + + // norm + { + cur = ggml_norm(ctx0, inpL, hparams.eps); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, layer.attn_ln_0_w), + layer.attn_ln_0_b); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, Qcur, layer.attn_q_b); + + //Qcur = ggml_scale(ctx0, Qcur, pow(float(n_state_head), -0.25)); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + //Kcur = ggml_scale(ctx0, Kcur, pow(float(n_state_head), -0.25)); + + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, Vcur, layer.attn_v_b); + + // ------ + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + if (wctx.params.flash_attn) { + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, ggml_view_1d(ctx0, kv_pad.k, n_ctx*n_state, 0))); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, ggml_view_1d(ctx0, kv_pad.v, n_ctx*n_state, 0))); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_pad.k, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.k)*n_state, + ggml_element_size(kv_pad.k)*n_state_head, + 0); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_pad.v, + n_state_head, n_ctx_pad, n_head, + ggml_element_size(kv_pad.v)*n_state, + ggml_element_size(kv_pad.v)*n_state_head, + 0); + + cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx); + } else { + struct ggml_tensor * K = + ggml_permute(ctx0, + 
ggml_cpy(ctx0, + Kcur, + ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + struct ggml_tensor * V = + ggml_cpy(ctx0, + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + Vcur, + n_state_head, n_head, n_ctx), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx)); + } + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, cur, layer.attn_ln_1_b); + } + + // add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF, hparams.eps); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, layer.mlp_ln_w), + layer.mlp_ln_b); + } + +#ifdef WHISPER_USE_FLASH_FF + cur = ggml_flash_ff(ctx0, + ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)), + layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b); +#else + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctx0, cur, layer.mlp_0_b); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctx0, cur, layer.mlp_1_b); +#endif + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, hparams.eps); + + // cur = ln_f_g*cur + ln_f_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, cur, model.e_ln_w), + model.e_ln_b); + } + + ggml_build_forward_expand(gf, cur); + + wstate.embd_enc = cur; + + //ggml_graph_print(gf); + + //////////////////////////////////////////////////////////////////////////// + + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); + + ggml_free(ctx0); + + return gf; +} + +// pre-compute cross-attention memory +static struct ggml_cgraph * whisper_build_graph_cross( + whisper_context & wctx, + whisper_state & wstate) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + const int n_state = hparams.n_audio_state; + const int n_head = hparams.n_audio_head; + + const int n_state_head = n_state/n_head; + + const int n_ctx_pad = GGML_PAD(n_ctx, 256); + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.alloc_cross.meta.size(), + /*.mem_buffer =*/ wstate.alloc_cross.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); + + const float Kscale = pow(float(n_state_head), -0.25); + + for (int il = 0; il < model.hparams.n_text_layer; ++il) { + auto & layer = model.layers_decoder[il]; + + struct ggml_tensor * Kcross = ggml_mul_mat(ctx0, + layer.cross_attn_k_w, + cur); + + Kcross = ggml_scale(ctx0, Kcross, Kscale); + + struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, + layer.cross_attn_v_w, + cur); + + Vcross = ggml_add(ctx0, + Vcross, + layer.cross_attn_v_b); + + struct ggml_tensor * k; + struct ggml_tensor * v; + + if (wctx.params.flash_attn) { + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx_pad)); + + v = ggml_view_1d(ctx0, wstate.kv_cross.v, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.v)*n_state)*(il*n_ctx_pad)); + } else { + Vcross = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcross, n_state, n_ctx)); + + k = ggml_view_1d(ctx0, wstate.kv_cross.k, n_state*n_ctx, + (ggml_element_size(wstate.kv_cross.k)*n_state)*(il*n_ctx)); + + v = ggml_view_2d(ctx0, wstate.kv_cross.v, n_ctx, n_state, + ( n_ctx)*ggml_element_size(wstate.kv_cross.v), + (il*n_ctx)*ggml_element_size(wstate.kv_cross.v)*n_state); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcross, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcross, v)); + } + + //ggml_graph_print(gf); + + ggml_free(ctx0); + + return gf; +} + +// evaluate the encoder with the given state +// +// given audio recording (more specifically, its log mel spectrogram), runs forward pass of the encoder +// part of the transformer model and returns the encoded features +// +// - wctx: the model +// - wstate: the state of the encoder +// - n_threads: number of threads to use +// - mel_offset: offset in the mel spectrogram (i.e. audio offset) +// +static bool whisper_encode_internal( + whisper_context & wctx, + whisper_state & wstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + // conv + { + auto & alloc = wstate.alloc_conv.alloc; + + ggml_cgraph * gf = whisper_build_graph_conv(wctx, wstate); + + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + + // set the input + { + const auto & mel_inp = wstate.mel; + const int n_ctx = wstate.exp_n_audio_ctx > 0 ? 
wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx; + + assert(mel->type == GGML_TYPE_F32); + assert(mel_inp.n_mel == wctx.model.hparams.n_mels); + + wstate.inp_mel.resize(ggml_nelements(mel)); + + float * dst = wstate.inp_mel.data(); + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + 2*n_ctx, mel_inp.n_len); + + for (int j = 0; j < mel_inp.n_mel; ++j) { + for (int i = i0; i < i1; ++i) { + dst[j*2*n_ctx + (i - i0)] = mel_inp.data[j*mel_inp.n_len + i]; + } + } + + ggml_backend_tensor_set(mel, wstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float)); + } + + if (!whisper_encode_external(wstate)) { + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; + } + } else { +#if defined(WHISPER_USE_COREML) + whisper_coreml_encode(wstate.ctx_coreml, mel->ne[0], mel->ne[1], (float *) mel->data, (float *) wstate.embd_enc->data); +#elif defined(WHISPER_USE_OPENVINO) + whisper_openvino_encode(wstate.ctx_openvino, mel, wstate.embd_enc); +#endif + } + } + + // encoder + if (!whisper_encode_external(wstate)) { + auto & alloc = wstate.alloc_encode.alloc; + + ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate); + + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; + } + } + + // cross + { + auto & alloc = wstate.alloc_cross.alloc; + + ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate); + + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; + } + } + + wstate.t_encode_us += ggml_time_us() - t_start_us; + wstate.n_encode++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static struct ggml_cgraph * whisper_build_graph_decoder( + whisper_context & wctx, + whisper_state & wstate, + const whisper_batch & batch, + bool save_alignment_heads_QKs, + bool worst_case) { + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + auto & kv_self = wstate.kv_self; + + WHISPER_ASSERT(!!kv_self.ctx); + + const int n_ctx = kv_self.size; + const int n_state = hparams.n_text_state; + const int n_head = hparams.n_text_head; + const int n_layer = hparams.n_text_layer; + + const int n_state_head = n_state/n_head; + + const int n_tokens = batch.n_tokens; + const int n_audio_ctx = wstate.exp_n_audio_ctx > 0 ? wstate.exp_n_audio_ctx : hparams.n_audio_ctx; + + const int n_audio_ctx_pad = GGML_PAD(n_audio_ctx, 256); + + const int32_t n_kv = worst_case ? n_ctx : kv_self.n; + const int32_t kv_head = worst_case ? 
n_ctx - n_tokens : kv_self.head; + + //WHISPER_LOG_DEBUG("%s: n_past = %d, n_tokens = %d, n_audio_ctx = %d, n_ctx = %d\n", __func__, n_past, n_tokens, n_audio_ctx, n_ctx); + + struct ggml_init_params params = { + /*.mem_size =*/ wstate.alloc_decode.meta.size(), + /*.mem_buffer =*/ wstate.alloc_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(embd, "embd"); + ggml_set_input(embd); + + struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(position, "position"); + ggml_set_input(position); + + const float KQscale = pow(float(n_state_head), -0.25); + + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_set_input(KQ_mask); + + struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); + + // token encoding + position encoding + struct ggml_tensor * cur = + ggml_add(ctx0, + ggml_get_rows(ctx0, model.d_te, embd), + ggml_get_rows(ctx0, model.d_pe, position)); + + struct ggml_tensor * inpL = cur; + + // [EXPERIMENTAL] Token-level timestamps with DTW + struct ggml_tensor * aheads_cross_QKs = nullptr; + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers_decoder[il]; + + // norm + { + cur = ggml_norm(ctx0, inpL, hparams.eps); + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.attn_ln_0_w), + layer.attn_ln_0_b); + } + + // self-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + Qcur, + layer.attn_q_b); + + Qcur = ggml_scale(ctx0, Qcur, KQscale); + + // note: no bias for Key + struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, + layer.attn_k_w, + cur); + + Kcur = ggml_scale(ctx0, Kcur, KQscale); + + // store key and value to memory + { + struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, + layer.attn_v_w, + cur); + + Vcur = ggml_add(ctx0, + Vcur, + layer.attn_v_b); + + struct ggml_tensor * k; + struct ggml_tensor * v; + + if (wctx.params.flash_attn) { + k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, + (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + + v = ggml_view_1d(ctx0, kv_self.v, n_tokens*n_state, + (ggml_element_size(kv_self.v)*n_state)*(il*n_ctx + kv_head)); + } else { + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_state, n_tokens)); + + k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_state, + (ggml_element_size(kv_self.k)*n_state)*(il*n_ctx + kv_head)); + + v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_state, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_state + kv_head*ggml_element_size(kv_self.v)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + // ------ + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_state_head, n_kv, n_head, + ggml_element_size(kv_self.k)*n_state, + ggml_element_size(kv_self.k)*n_state_head, + ggml_element_size(kv_self.k)*n_state*n_ctx*il); + + if (wctx.params.flash_attn) { + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_state_head, n_kv, n_head, + 
ggml_element_size(kv_self.v)*n_state, + ggml_element_size(kv_self.v)*n_state_head, + ggml_element_size(kv_self.v)*n_state*n_ctx*il); + + cur = ggml_flash_attn_ext(ctx0, Q, K, V, KQ_mask_f16, 1.0f, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } else { + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask, 1.0f, 0.0f); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_state_head, n_head, + n_ctx*ggml_element_size(kv_self.v), + n_ctx*ggml_element_size(kv_self.v)*n_state_head, + n_ctx*ggml_element_size(kv_self.v)*n_state*il); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.attn_ln_1_b); + } + + // add the input + struct ggml_tensor * inpCA = ggml_add(ctx0, cur, inpL); + + // norm + { + cur = ggml_norm(ctx0, inpCA, hparams.eps); // note: we use inpCA here + + // cur = ln_0_w*cur + ln_0_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.cross_attn_ln_0_w), + layer.cross_attn_ln_0_b); + } + + // cross-attention + { + struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, + layer.cross_attn_q_w, + cur); + + Qcur = ggml_add(ctx0, + Qcur, + layer.cross_attn_q_b); + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_tokens), + 0, 2, 1, 3); + + if (wctx.params.flash_attn) { + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state_head, + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx_pad*il); + + struct ggml_tensor * Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_state_head, n_audio_ctx_pad, n_head, + ggml_element_size(wstate.kv_cross.v)*n_state, + ggml_element_size(wstate.kv_cross.v)*n_state_head, + ggml_element_size(wstate.kv_cross.v)*n_state*n_audio_ctx_pad*il); + + cur = ggml_flash_attn_ext(ctx0, Q, Kcross, Vcross, nullptr, KQscale, 0.0f); + + cur = ggml_reshape_2d(ctx0, cur, n_state, n_tokens); + } else { + struct ggml_tensor * Kcross = + ggml_view_3d(ctx0, wstate.kv_cross.k, + n_state_head, n_audio_ctx, n_head, + ggml_element_size(wstate.kv_cross.k)*n_state, + ggml_element_size(wstate.kv_cross.k)*n_state_head, + ggml_element_size(wstate.kv_cross.k)*n_state*n_audio_ctx*il); + + struct ggml_tensor * Vcross = + ggml_view_3d(ctx0, wstate.kv_cross.v, + n_audio_ctx, n_state_head, n_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v), + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state_head, + n_audio_ctx*ggml_element_size(wstate.kv_cross.v)*n_state*il); + + // ------ + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, Kcross, Q); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (wctx.params.dtw_token_timestamps) { + if (wstate.aheads_masks.m[il] != nullptr) { + struct ggml_tensor * aheads_KQs = ggml_reshape_2d(ctx0, KQ_soft_max, KQ_soft_max->ne[0] * KQ_soft_max->ne[1], KQ_soft_max->ne[2]); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_mul_mat(ctx0, 
wstate.aheads_masks.m[il], aheads_KQs); + aheads_KQs = ggml_transpose(ctx0, aheads_KQs); + aheads_KQs = ggml_cont(ctx0, aheads_KQs); + aheads_KQs = ggml_reshape_3d(ctx0, aheads_KQs, KQ_soft_max->ne[0], KQ_soft_max->ne[1], wstate.aheads_masks.m[il]->ne[1]); + if (aheads_cross_QKs == NULL) { + aheads_cross_QKs = aheads_KQs; + } else { + aheads_cross_QKs = ggml_concat(ctx0, aheads_cross_QKs, aheads_KQs,2); + } + } + } + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, Vcross, KQ_soft_max); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens)); + } + } + + // projection + { + cur = ggml_mul_mat(ctx0, + layer.cross_attn_ln_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.cross_attn_ln_1_b); + } + + // add the input + cur = ggml_add(ctx0, cur, inpCA); + + struct ggml_tensor * inpFF = cur; + + // feed-forward network + { + // norm + { + cur = ggml_norm(ctx0, inpFF, hparams.eps); + + // cur = mlp_ln_w*cur + mlp_ln_b + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + layer.mlp_ln_w), + layer.mlp_ln_b); + } + + // fully connected + cur = ggml_mul_mat(ctx0, + layer.mlp_0_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.mlp_0_b); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + cur = ggml_mul_mat(ctx0, + layer.mlp_1_w, + cur); + + cur = ggml_add(ctx0, + cur, + layer.mlp_1_b); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, hparams.eps); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + cur, + model.d_ln_w), + model.d_ln_b); + } + + // compute logits only for the last token + // comment this line to compute logits for all n_tokens + // might be useful in the future + //cur = ggml_view_2d(ctx0, cur, cur->ne[0], 1, cur->nb[1], (cur->ne[1] - 1)*cur->nb[1]); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.d_te, cur); + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (wctx.params.dtw_token_timestamps && aheads_cross_QKs != nullptr) { + aheads_cross_QKs = ggml_transpose(ctx0, aheads_cross_QKs); + aheads_cross_QKs = ggml_cont(ctx0, aheads_cross_QKs); + if (save_alignment_heads_QKs) { + ggml_build_forward_expand(gf, aheads_cross_QKs); + wstate.aheads_cross_QKs = aheads_cross_QKs; + } + } + + ggml_build_forward_expand(gf, logits); + + ggml_free(ctx0); + + return gf; +} + +// evaluate the decoder +// +// given text prompt + audio features -> computes the logits for the next token +// +// - model: the model +// - n_threads: number of threads to use +// - tokens: text prompt +// - n_tokens: number of tokens in the prompt +// - n_past: number of past tokens to prefix the prompt with +// +static bool whisper_decode_internal( + whisper_context & wctx, + whisper_state & wstate, + const whisper_batch & batch, + const int n_threads, + bool save_alignment_heads_QKs, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = wctx.model; + const auto & hparams = model.hparams; + + const int n_vocab = hparams.n_vocab; + const int n_tokens = batch.n_tokens; + + auto & logits_out = wstate.logits; + + struct ggml_tensor * logits; + + // find KV slot for the batch + { + auto & kv_self = wstate.kv_self; + + if (!whisper_kv_cache_find_slot(kv_self, batch)) { + return false; + } + + const uint32_t pad = whisper_kv_cache_get_padding(wctx); + kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(whisper_kv_cache_cell_max(kv_self), 
pad))); + + //kv_self.n = std::min((int32_t) hparams.n_text_ctx, std::max(32, whisper_kv_cache_cell_max(kv_self))); + //printf("n_tokens = %5d, kv_self.head = %5d, kv_self.n = %5d, seq_id = %5d\n", batch.n_tokens, kv_self.head, kv_self.n, batch.seq_id[0][0]); + } + + // decoder + { + auto & alloc = wstate.alloc_decode.alloc; + + ggml_cgraph * gf = whisper_build_graph_decoder(wctx, wstate, batch, save_alignment_heads_QKs, false); + + if (!ggml_gallocr_alloc_graph(alloc, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + // set the inputs + { + struct ggml_tensor * embd = ggml_graph_get_tensor(gf, "embd"); + ggml_backend_tensor_set(embd, batch.token, 0, n_tokens*ggml_element_size(embd)); + } + + { + struct ggml_tensor * position = ggml_graph_get_tensor(gf, "position"); + for (int i = 0; i < n_tokens; ++i) { + const int32_t val = batch.pos[i]; + ggml_backend_tensor_set(position, &val, i*sizeof(int32_t), sizeof(int32_t)); + } + } + + { + struct ggml_tensor * KQ_mask = ggml_graph_get_tensor(gf, "KQ_mask"); + + auto & kv_self = wstate.kv_self; + + const int32_t n_kv = kv_self.n; + + wstate.inp_mask.resize(ggml_nelements(KQ_mask)); + + float * data = wstate.inp_mask.data(); + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const whisper_pos pos = batch.pos[j]; + const whisper_seq_id seq_id = batch.seq_id[j][0]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + + for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { + for (int j = 0; j < n_kv; ++j) { + data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + } + } + } + + ggml_backend_tensor_set(KQ_mask, wstate.inp_mask.data(), 0, ggml_nelements(KQ_mask)*sizeof(float)); + } + + logits = gf->nodes[gf->n_nodes - 1]; + + if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) { + return false; + } + } + + logits_out.resize(n_tokens*n_vocab); + for (int i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + ggml_backend_tensor_get(logits, logits_out.data() + (n_vocab*i), sizeof(float)*(n_vocab*i), sizeof(float)*n_vocab); + } + + if (batch.n_tokens > 1) { + //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__, + // ggml_used_mem(ctx0)/1e6, + // wstate.get_buf_max_mem(0)/1e6, + // wstate.get_buf_max_mem(1)/1e6, + // wstate.get_buf_max_mem(2)/1e6, + // wstate.get_buf_max_mem(3)/1e6); + } + + if (batch.n_tokens == 1) { + wstate.t_decode_us += ggml_time_us() - t_start_us; + wstate.n_decode++; + } else if (batch.n_tokens < 16) { + wstate.t_batchd_us += ggml_time_us() - t_start_us; + wstate.n_batchd += n_tokens; + } else { + wstate.t_prompt_us += ggml_time_us() - t_start_us; + wstate.n_prompt += n_tokens; + } + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +std::string to_timestamp(int64_t t, bool comma) { + int64_t msec = t * 10; + int64_t hr = msec / (1000 * 60 * 60); + msec = msec - hr * (1000 * 60 * 60); + int64_t min = msec / (1000 * 60); + msec = msec - min * (1000 * 60); + int64_t sec = msec / 1000; + msec = msec - sec * 1000; + + char buf[32]; + snprintf(buf, sizeof(buf), "%02d:%02d:%02d%s%03d", (int) hr, (int) min, (int) sec, comma ? 
"," : ".", (int) msec); + + return std::string(buf); +} + +#define SIN_COS_N_COUNT WHISPER_N_FFT +static float sin_vals[SIN_COS_N_COUNT]; +static float cos_vals[SIN_COS_N_COUNT]; + +// In FFT, we frequently use sine and cosine operations with the same values. +// We can use precalculated values to speed up the process. +static void fill_sin_cos_table() { + static bool is_filled = false; + if (is_filled) return; + for (int i = 0; i < SIN_COS_N_COUNT; i++) { + double theta = (2*M_PI*i)/SIN_COS_N_COUNT; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + is_filled = true; +} + +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const std::vector & in, std::vector & out) { + int N = in.size(); + + out.resize(N*2); + const int sin_cos_step = SIN_COS_N_COUNT / N; + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + int idx = (k * n * sin_cos_step) % (SIN_COS_N_COUNT); // t = 2*M_PI*k*n/N + re += in[n]*cos_vals[idx]; // cos(t) + im -= in[n]*sin_vals[idx]; // sin(t) + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(const std::vector & in, std::vector & out) { + out.resize(in.size()*2); + + int N = in.size(); + + if (N == 1) { + out[0] = in[0]; + out[1] = 0; + return; + } + + if (N%2 == 1) { + dft(in, out); + return; + } + + std::vector even; + std::vector odd; + + even.reserve(N/2); + odd.reserve(N/2); + + for (int i = 0; i < N; i++) { + if (i % 2 == 0) { + even.push_back(in[i]); + } else { + odd.push_back(in[i]); + } + } + + std::vector even_fft; + std::vector odd_fft; + + fft(even, even_fft); + fft(odd, odd_fft); + + const int sin_cos_step = SIN_COS_N_COUNT / N; + for (int k = 0; k < N/2; k++) { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cos_vals[idx]; // cos(t) + float im = -sin_vals[idx]; // sin(t) + + float re_odd = odd_fft[2*k + 0]; + float im_odd = odd_fft[2*k + 1]; + + out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; + out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + + out[2*(k + N/2) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + N/2) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + } +} + +static bool hann_window(int length, bool periodic, std::vector & output) { + if (output.size() < static_cast(length)) { + output.resize(length); + } + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5*(1.0 - cosf((2.0*M_PI*i)/(length + offset))); + } + + return true; +} + +static void log_mel_spectrogram_worker_thread(int ith, const std::vector & hann, const std::vector & samples, + int n_samples, int frame_size, int frame_step, int n_threads, + const whisper_filters & filters, whisper_mel & mel) { + std::vector fft_in(frame_size, 0.0); + std::vector fft_out(2 * frame_size); + int n_fft = filters.n_fft; + int i = ith; + + // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist + assert(n_fft == 1 + (frame_size / 2)); + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(n_samples / frame_step + 1, mel.n_len); i += n_threads) { + const int offset = i * frame_step; + + // apply Hanning window (~10% faster) + for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { + fft_in[j] = hann[j] * samples[offset + j]; + } + // fill the rest with zeros + if (n_samples - offset < frame_size) { 
+ std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0); + + // FFT + fft(fft_in, fft_out); + + // Calculate modulus^2 of complex numbers + // Note: computing this as pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) caused inference quality problems - interesting + for (int j = 0; j < n_fft; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fft - 3; k += 4) { + sum += + fft_out[k + 0] * filters.data[j * n_fft + k + 0] + + fft_out[k + 1] * filters.data[j * n_fft + k + 1] + + fft_out[k + 2] * filters.data[j * n_fft + k + 2] + + fft_out[k + 3] * filters.data[j * n_fft + k + 3]; + } + + // handle n_fft remainder + for (; k < n_fft; k++) { + sum += fft_out[k] * filters.data[j * n_fft + k]; + } + + sum = log10(std::max(sum, 1e-10)); + + mel.data[j * mel.n_len + i] = sum; + } + } + + // Otherwise fft_out is all zeros (these frames lie entirely past the end of the signal) + double sum = log10(1e-10); + for (; i < mel.n_len; i += n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[j * mel.n_len + i] = sum; + } + } +} + +// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L110-L157 +static bool log_mel_spectrogram( + whisper_state & wstate, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int frame_size, + const int frame_step, + const int n_mel, + const int n_threads, + const whisper_filters & filters, + const bool debug, + whisper_mel & mel) { + const int64_t t_start_us = ggml_time_us(); + + // Hann window (use cosf to eliminate floating-point differences with the reference implementation) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + std::vector<float> hann; + hann_window(frame_size, true, hann); + + // Calculate the length of padding + int64_t stage_1_pad = WHISPER_SAMPLE_RATE * 30; + int64_t stage_2_pad = frame_size / 2; + + // Initialize a vector and copy data from the C array into it.
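+ // resulting layout: [ stage_2_pad reflected samples | n_samples of audio | stage_1_pad + stage_2_pad zeros ]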
+ std::vector<float> samples_padded; + samples_padded.resize(n_samples + stage_1_pad + stage_2_pad * 2); + std::copy(samples, samples + n_samples, samples_padded.begin() + stage_2_pad); + + // pad 30 seconds of zeros at the end of audio (480,000 samples) + reflective pad 200 samples at the end of audio + std::fill(samples_padded.begin() + n_samples + stage_2_pad, samples_padded.begin() + n_samples + stage_1_pad + 2 * stage_2_pad, 0); + + // reflective pad 200 samples at the beginning of audio + std::reverse_copy(samples + 1, samples + 1 + stage_2_pad, samples_padded.begin()); + + mel.n_mel = n_mel; + // https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/SpectralOps.cpp#L936 + // Calculate number of frames + remove the last frame + mel.n_len = (samples_padded.size() - frame_size) / frame_step; + // Calculate semi-padded sample length to ensure compatibility + mel.n_len_org = 1 + (n_samples + stage_2_pad - frame_size) / frame_step; + mel.data.resize(mel.n_mel * mel.n_len); + + { + std::vector<std::thread> workers(n_threads - 1); + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw] = std::thread( + log_mel_spectrogram_worker_thread, iw + 1, std::cref(hann), samples_padded, + n_samples + stage_2_pad, frame_size, frame_step, n_threads, + std::cref(filters), std::ref(mel)); + } + + // main thread + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples + stage_2_pad, frame_size, frame_step, n_threads, filters, mel); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + // clamping and normalization + double mmax = -1e20; + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] > mmax) { + mmax = mel.data[i]; + } + } + + mmax -= 8.0; + + for (int i = 0; i < mel.n_mel*mel.n_len; i++) { + if (mel.data[i] < mmax) { + mel.data[i] = mmax; + } + + mel.data[i] = (mel.data[i] + 4.0)/4.0; + } + + wstate.t_mel_us += ggml_time_us() - t_start_us; + + // Dump log_mel_spectrogram + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } + + return true; +} + +// split text into tokens +// +// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 +// +// Regex (Python): +// r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""" +// +// Regex (C++): +// R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)" +// +static std::vector<whisper_vocab::id> tokenize(const whisper_vocab & vocab, const std::string & text) { + std::vector<std::string> words; + + // first split the text into words + { + std::string str = text; + std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; + + std::regex re(pat); + std::smatch m; + + while (std::regex_search(str, m, re)) { + for (auto x : m) { + words.push_back(x); + } + str = m.suffix(); + } + } + + // greedily find the longest vocab tokens that form each word: + std::vector<whisper_vocab::id> tokens; + for (const auto & word : words) { + if (word.empty()) continue; + + int i = 0; + int n = word.size(); + while (i < n) { + int j = n; + bool found = false; + while (j > i) { + auto sub = word.substr(i, j-i); + auto it = vocab.token_to_id.find(sub); + if (it != vocab.token_to_id.end()) { + tokens.push_back(it->second); + i = j; + found = true; + break; + } + --j; + } + if (!found) { + WHISPER_LOG_ERROR("unknown
token\n"); + ++i; + } + } + } + + return tokens; +} + +// +// interface implementation +// + +#ifdef WHISPER_USE_COREML +// replace .bin with -encoder.mlmodelc +static std::string whisper_get_coreml_path_encoder(std::string path_bin) { + auto pos = path_bin.rfind('.'); + if (pos != std::string::npos) { + path_bin = path_bin.substr(0, pos); + } + + // match "-qx_x" + pos = path_bin.rfind('-'); + if (pos != std::string::npos) { + auto sub = path_bin.substr(pos); + if (sub.size() == 5 && sub[1] == 'q' && sub[3] == '_') { + path_bin = path_bin.substr(0, pos); + } + } + + path_bin += "-encoder.mlmodelc"; + + return path_bin; +} +#endif + +#ifdef WHISPER_USE_OPENVINO +// replace .bin with-encoder-openvino.xml +static std::string whisper_openvino_get_path_encoder(std::string path_bin) { + auto pos = path_bin.rfind('.'); + if (pos != std::string::npos) { + path_bin = path_bin.substr(0, pos); + } + + path_bin += "-encoder-openvino.xml"; + + return path_bin; +} + +static std::string whisper_openvino_get_path_cache(std::string path_bin) { + auto pos = path_bin.rfind('.'); + if (pos != std::string::npos) { + path_bin = path_bin.substr(0, pos); + } + + path_bin += "-encoder-openvino-cache"; + + return path_bin; +} +#endif + +struct whisper_state * whisper_init_state(whisper_context * ctx) { + fill_sin_cos_table(); + + whisper_state * state = new whisper_state; + + state->backend = whisper_backend_init(ctx->params); + if (!state->backend) { + WHISPER_LOG_ERROR("%s: whisper_backend_init() failed\n", __func__); + whisper_free_state(state); + return nullptr; + } + + // at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx + // in theory, there can be a case where this is not enough, but in practice it should always be enough + const int factor = 3; + + if (!kv_cache_init(state->kv_self, ctx->backend, ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_self.k) + ggml_nbytes(state->kv_self.v); + WHISPER_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1e6); + } + + if (!kv_cache_init(state->kv_cross, ctx->backend, ctx->itype, + ctx->model.hparams.n_text_state, + ctx->model.hparams.n_text_layer, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for cross-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_cross.k) + ggml_nbytes(state->kv_cross.v); + WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); + } + + if (!kv_cache_init(state->kv_pad, ctx->backend, ctx->itype, + ctx->model.hparams.n_audio_state, + 1, + GGML_PAD(ctx->model.hparams.n_audio_ctx, 256))) { + WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); + whisper_free_state(state); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(state->kv_pad.k) + ggml_nbytes(state->kv_pad.v); + WHISPER_LOG_INFO("%s: kv pad size = %7.2f MB\n", __func__, memory_size / 1e6); + } + + // [EXPERIMENTAL] Token-level timestamps with DTW + if (ctx->params.dtw_token_timestamps) { + if (!aheads_masks_init(ctx->params, ctx->model.hparams, state->aheads_masks, ctx->backend)) { + WHISPER_LOG_ERROR("%s: aheads_masks_init() failed for 
alignment heads masks\n", __func__); + whisper_free_state(state); + return nullptr; + } + const size_t memory_size = aheads_masks_nbytes(state->aheads_masks); + WHISPER_LOG_INFO("%s: alignment heads masks size = %zu B\n", __func__, memory_size); + } + +#ifdef WHISPER_USE_COREML + const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); + + WHISPER_LOG_INFO("%s: loading Core ML model from '%s'\n", __func__, path_coreml.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); + + state->ctx_coreml = whisper_coreml_init(path_coreml.c_str()); + if (!state->ctx_coreml) { + WHISPER_LOG_ERROR("%s: failed to load Core ML model from '%s'\n", __func__, path_coreml.c_str()); +#ifndef WHISPER_COREML_ALLOW_FALLBACK + whisper_free_state(state); + return nullptr; +#endif + } else { + WHISPER_LOG_INFO("%s: Core ML model loaded\n", __func__); + } +#endif + + state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); + + state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS); + + // TAGS: WHISPER_DECODER_INIT + state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); + + state->decoders[0].probs.reserve (ctx->vocab.n_vocab); + state->decoders[0].logits.reserve (ctx->vocab.n_vocab); + state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab); + state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab); + + state->decoders[0].rng = std::mt19937(0); + + // conv allocator + { + bool ok = whisper_allocr_graph_init(state->alloc_conv, ctx->backend, + [&]() { + return whisper_build_graph_conv(*ctx, *state); + }); + + if (!ok) { + WHISPER_LOG_ERROR("%s: failed to init conv allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (conv) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_conv) / 1e6); + } + + // encoder allocator + if (!whisper_encode_external(*state)) { + bool ok = whisper_allocr_graph_init(state->alloc_encode, ctx->backend, + [&]() { + return whisper_build_graph_encoder(*ctx, *state); + }); + + if (!ok) { + WHISPER_LOG_ERROR("%s: failed to init encoder allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_encode) / 1e6); + } + + // cross allocator + { + bool ok = whisper_allocr_graph_init(state->alloc_cross, ctx->backend, + [&]() { + return whisper_build_graph_cross(*ctx, *state); + }); + + if (!ok) { + WHISPER_LOG_ERROR("%s: failed to init cross allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (cross) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_cross) / 1e6); + } + + // decoder allocator + { + bool ok = whisper_allocr_graph_init(state->alloc_decode, ctx->backend, + [&]() { + const auto & hparams = ctx->model.hparams; + + // TODO: make sure this is the worst-case scenario + const int n_tokens = hparams.n_text_ctx; + const int n_past = 0; + + whisper_batch_prep_legacy(state->batch, nullptr, n_tokens, n_past, 0); + + return whisper_build_graph_decoder(*ctx, *state, state->batch, ctx->params.dtw_token_timestamps, true); + }); + + if (!ok) { + WHISPER_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); + whisper_free_state(state); + return nullptr; + } + + WHISPER_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, whisper_allocr_size(state->alloc_decode) / 1e6); + } + + return state;
+} + +int whisper_ctx_init_openvino_encoder( + struct whisper_context * ctx, + const char * model_path, + const char * device, + const char * cache_dir) { +#ifndef WHISPER_USE_OPENVINO + (void)(ctx); + (void)(model_path); + (void)(device); + (void)(cache_dir); + + return 1; +#else + if (!model_path && ctx->path_model.empty()) { + WHISPER_LOG_ERROR("%s: model_path is nullptr, and ctx has no model_path set.\n", __func__); + return 1; + } + + std::string path_encoder; + if (!model_path) { + // if model_path is not set, attempt to find it in the same directory as the ggml-<model>.bin model + path_encoder = whisper_openvino_get_path_encoder(ctx->path_model); + } else { + path_encoder = model_path; + } + + std::string path_cache; + if (!cache_dir) { + // if cache_dir is not set, set it as a dir residing next to ggml-<model>.bin + path_cache = whisper_openvino_get_path_cache(ctx->path_model); + } else { + path_cache = cache_dir; + } + + WHISPER_LOG_INFO("%s: loading OpenVINO model from '%s'\n", __func__, path_encoder.c_str()); + WHISPER_LOG_INFO("%s: first run on a device may take a while ...\n", __func__); + + ctx->state->ctx_openvino = whisper_openvino_init(path_encoder.c_str(), device, path_cache.c_str()); + if (!ctx->state->ctx_openvino) { + WHISPER_LOG_ERROR("%s: failed to init OpenVINO encoder from '%s'\n", __func__, path_encoder.c_str()); + return 1; + } else { + WHISPER_LOG_INFO("%s: OpenVINO model loaded\n", __func__); + } + + return 0; +#endif +} + +struct whisper_context_params whisper_context_default_params() { + struct whisper_context_params result = { + /*.use_gpu =*/ true, + /*.flash_attn =*/ false, + /*.gpu_device =*/ 0, + + /*.dtw_token_timestamps =*/ false, + /*.dtw_aheads_preset =*/ WHISPER_AHEADS_NONE, + /*.dtw_n_top =*/ -1, + /*.dtw_aheads =*/ { + /*.n_heads =*/ 0, + /*.heads =*/ NULL, + }, + /*.dtw_mem_size =*/ 1024*1024*128, + }; + return result; +} + +struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params) { + WHISPER_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues.
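+ // (note: std::wstring_convert is deprecated since C++17, though still available; the Win32 MultiByteToWideChar API would be an alternative)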
+ std::wstring_convert<std::codecvt_utf8<wchar_t>> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else + auto fin = std::ifstream(path_model, std::ios::binary); +#endif + if (!fin) { + WHISPER_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + whisper_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void * ctx, void * output, size_t read_size) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.eof = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + return fin->eof(); + }; + + loader.close = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->close(); + }; + + auto ctx = whisper_init_with_params_no_state(&loader, params); + + if (ctx) { + ctx->path_model = path_model; + } + + return ctx; +} + +struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params) { + struct buf_context { + uint8_t* buffer; + size_t size; + size_t current_offset; + }; + + buf_context ctx = { reinterpret_cast<uint8_t *>(buffer), buffer_size, 0 }; + + WHISPER_LOG_INFO("%s: loading model from buffer\n", __func__); + + whisper_model_loader loader = {}; + + loader.context = &ctx; + + loader.read = [](void * ctx, void * output, size_t read_size) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset; + + memcpy(output, buf->buffer + buf->current_offset, size_to_copy); + buf->current_offset += size_to_copy; + + return size_to_copy; + }; + + loader.eof = [](void * ctx) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + return buf->current_offset >= buf->size; + }; + + loader.close = [](void * /*ctx*/) { }; + + return whisper_init_with_params_no_state(&loader, params); +} + +struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) { + ggml_time_init(); + + if (params.flash_attn && params.dtw_token_timestamps) { + WHISPER_LOG_WARN("%s: dtw_token_timestamps is not supported with flash_attn - disabling\n", __func__); + params.dtw_token_timestamps = false; + } + + WHISPER_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + WHISPER_LOG_INFO("%s: flash attn = %d\n", __func__, params.flash_attn); + WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + WHISPER_LOG_INFO("%s: dtw = %d\n", __func__, params.dtw_token_timestamps); + + whisper_context * ctx = new whisper_context; + ctx->params = params; + + if (!whisper_model_load(loader, *ctx)) { + loader->close(loader->context); + WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + return ctx; +} + +struct whisper_context * whisper_init_from_file_with_params(const char * path_model, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_from_file_with_params_no_state(path_model, params); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_from_buffer_with_params_no_state(buffer, buffer_size,
params); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init_with_params(struct whisper_model_loader * loader, struct whisper_context_params params) { + whisper_context * ctx = whisper_init_with_params_no_state(loader, params); + if (!ctx) { + return nullptr; + } + + ctx->state = whisper_init_state(ctx); + if (!ctx->state) { + whisper_free(ctx); + return nullptr; + } + + return ctx; +} + +struct whisper_context * whisper_init_from_file(const char * path_model) { + return whisper_init_from_file_with_params(path_model, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size) { + return whisper_init_from_buffer_with_params(buffer, buffer_size, whisper_context_default_params()); +} + +struct whisper_context * whisper_init(struct whisper_model_loader * loader) { + return whisper_init_with_params(loader, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_from_file_no_state(const char * path_model) { + return whisper_init_from_file_with_params_no_state(path_model, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size) { + return whisper_init_from_buffer_with_params_no_state(buffer, buffer_size, whisper_context_default_params()); +} + +struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader) { + return whisper_init_with_params_no_state(loader, whisper_context_default_params()); +} + +void whisper_free_state(struct whisper_state * state) { + if (state) { + kv_cache_free(state->kv_self); + kv_cache_free(state->kv_cross); + kv_cache_free(state->kv_pad); + +#ifdef WHISPER_USE_COREML + if (state->ctx_coreml != nullptr) { + whisper_coreml_free(state->ctx_coreml); + state->ctx_coreml = nullptr; + } +#endif + +#ifdef WHISPER_USE_OPENVINO + if (state->ctx_openvino != nullptr) { + whisper_openvino_free(state->ctx_openvino); + state->ctx_openvino = nullptr; + } +#endif + + whisper_batch_free(state->batch); + + ggml_gallocr_free(state->alloc_conv.alloc); + ggml_gallocr_free(state->alloc_encode.alloc); + ggml_gallocr_free(state->alloc_cross.alloc); + ggml_gallocr_free(state->alloc_decode.alloc); + + ggml_backend_free(state->backend); + + // [EXPERIMENTAL] Token-level timestamps with DTW + aheads_masks_free(state->aheads_masks); + + delete state; + } +} + +void whisper_free(struct whisper_context * ctx) { + if (ctx) { + ggml_free(ctx->model.ctx); + + ggml_backend_buffer_free(ctx->model.buffer); + + whisper_free_state(ctx->state); + + ggml_backend_free(ctx->backend); + + delete ctx; + } +} + +void whisper_free_context_params(struct whisper_context_params * params) { + if (params) { + delete params; + } +} + +void whisper_free_params(struct whisper_full_params * params) { + if (params) { + delete params; + } +} + +int whisper_pcm_to_mel_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, WHISPER_N_FFT, WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { + WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + 
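+ // convenience wrapper around whisper_pcm_to_mel_with_state() that uses the context's own state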
return whisper_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) +int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, samples, n_samples, WHISPER_SAMPLE_RATE, 2 * WHISPER_N_FFT, 2 * WHISPER_HOP_LENGTH, ctx->model.filters.n_mel, n_threads, ctx->model.filters, false, state->mel)) { + WHISPER_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +// same as whisper_pcm_to_mel, but applies a Phase Vocoder to speed up the audio x2 (PV without phase lock is not good) +int whisper_pcm_to_mel_phase_vocoder(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads) { + return whisper_pcm_to_mel_phase_vocoder_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +// same as whisper_pcm_to_mel, but applies WSOLA to speed up the audio x2 +// TODO + +// same as whisper_pcm_to_mel, but applies HPTSM to speed up the audio x2 +// TODO + +// same as whisper_pcm_to_mel, but applies PV (with phase lock) to speed up the audio x2 +// TODO + +int whisper_set_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != ctx->model.filters.n_mel) { + WHISPER_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + return -1; + } + + state->mel.n_len = n_len; + state->mel.n_len_org = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel) { + return whisper_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int whisper_encode_with_state(struct whisper_context * ctx, struct whisper_state * state, int offset, int n_threads) { + if (!whisper_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int whisper_encode(struct whisper_context * ctx, int offset, int n_threads) { + if (!whisper_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int whisper_decode_with_state(struct whisper_context * ctx, struct whisper_state * state, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + whisper_batch_prep_legacy(state->batch, tokens, n_tokens, n_past, 0); + + whisper_kv_cache_seq_rm(state->kv_self, 0, n_past, -1); + + if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, false, nullptr, nullptr)) { + WHISPER_LOG_ERROR("%s: failed to eval\n", __func__); + return 1; + } + + return 0; +} + +int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads) { + if (ctx->state == nullptr) { + WHISPER_LOG_ERROR("%s: ERROR state was not loaded.\n", __func__); + return -1; + } + + return whisper_decode_with_state(ctx, ctx->state, tokens, n_tokens, n_past, n_threads); +} + +int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens) { + const auto res = 
tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + WHISPER_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -(int) res.size(); + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int whisper_token_count(struct whisper_context * ctx, const char * text) { + return -whisper_tokenize(ctx, text, NULL, 0); +} + +int whisper_lang_max_id() { + auto max_id = 0; + for (const auto & kv : g_lang) { + max_id = std::max(max_id, kv.second.first); + } + + return max_id; +} + +int whisper_lang_id(const char * lang) { + if (!g_lang.count(lang)) { + for (const auto & kv : g_lang) { + if (kv.second.second == lang) { + return kv.second.first; + } + } + + WHISPER_LOG_ERROR("%s: unknown language '%s'\n", __func__, lang); + return -1; + } + return g_lang.at(lang).first; +} + +const char * whisper_lang_str(int id) { + for (const auto & kv : g_lang) { + if (kv.second.first == id) { + return kv.first.c_str(); + } + } + + WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id); + return nullptr; +} + +const char * whisper_lang_str_full(int id) { + for (const auto & kv : g_lang) { + if (kv.second.first == id) { + return kv.second.second.c_str(); + } + } + + WHISPER_LOG_ERROR("%s: unknown language id %d\n", __func__, id); + return nullptr; +} + +int whisper_lang_auto_detect_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs) { + const int seek = offset_ms/10; + + if (seek < 0) { + WHISPER_LOG_ERROR("%s: offset %dms is before the start of the audio\n", __func__, offset_ms); + return -1; + } + + if (seek >= state->mel.n_len_org) { + WHISPER_LOG_ERROR("%s: offset %dms is past the end of the audio (%dms)\n", __func__, offset_ms, state->mel.n_len_org*10); + return -2; + } + + // run the encoder + if (whisper_encode_with_state(ctx, state, seek, n_threads) != 0) { + WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + const std::vector<whisper_token> prompt = { whisper_token_sot(ctx) }; + + if (whisper_decode_with_state(ctx, state, prompt.data(), prompt.size(), 0, n_threads) != 0) { + WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + auto & logits_id = state->decoders[0].logits_id; + logits_id.clear(); + + for (const auto & kv : g_lang) { + const auto token_lang = whisper_token_lang(ctx, kv.second.first); + logits_id.emplace_back(state->logits[token_lang], kv.second.first); + } + + // sort descending + { + using pair_type = std::remove_reference<decltype(logits_id)>::type::value_type; + std::sort(logits_id.begin(), logits_id.end(), [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + // softmax + { + const auto max = logits_id[0].first; + + double sum = 0.0; + for (auto & kv : logits_id) { + kv.first = exp(kv.first - max); + sum += kv.first; + } + + for (auto & kv : logits_id) { + kv.first /= sum; + } + } + + { + for (const auto & prob : logits_id) { + if (lang_probs) { + lang_probs[prob.second] = prob.first; + } + + //printf("%s: lang %2d (%3s): %f\n", __func__, prob.second, whisper_lang_str(prob.second), prob.first); + } + } + + return logits_id[0].second; +} + +int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs) { + return whisper_lang_auto_detect_with_state(ctx, ctx->state, offset_ms, n_threads, lang_probs); +} + +int whisper_model_n_vocab(struct whisper_context * ctx) { + return
ctx->model.hparams.n_vocab; +} + +int whisper_model_n_audio_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_model_n_audio_state(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int whisper_model_n_audio_head(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int whisper_model_n_audio_layer(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int whisper_model_n_text_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_ctx; +} + +int whisper_model_n_text_state(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_state; +} + +int whisper_model_n_text_head(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_head; +} + +int whisper_model_n_text_layer(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_layer; +} + +int whisper_model_n_mels(struct whisper_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int whisper_model_ftype(struct whisper_context * ctx) { + return ctx->model.hparams.ftype; +} + +int whisper_model_type(struct whisper_context * ctx) { + return ctx->model.type; +} + +const char *whisper_model_type_readable(struct whisper_context * ctx) { + switch (ctx->model.type) { + case e_model::MODEL_TINY: + return "tiny"; + case e_model::MODEL_BASE: + return "base"; + case e_model::MODEL_SMALL: + return "small"; + case e_model::MODEL_MEDIUM: + return "medium"; + case e_model::MODEL_LARGE: + return "large"; + default: + return "unknown"; + } +} + +int whisper_n_len_from_state(struct whisper_state * state) { + return state->mel.n_len_org; +} + +int whisper_n_len(struct whisper_context * ctx) { + return ctx->state->mel.n_len_org; +} + +int whisper_n_vocab(struct whisper_context * ctx) { + return ctx->vocab.n_vocab; +} + +int whisper_n_text_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_text_ctx; +} + +int whisper_n_audio_ctx(struct whisper_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int whisper_is_multilingual(struct whisper_context * ctx) { + return ctx->vocab.is_multilingual() ? 
1 : 0; +} + +float * whisper_get_logits(struct whisper_context * ctx) { + return ctx->state->logits.data(); +} + +float * whisper_get_logits_from_state(struct whisper_state * state) { + return state->logits.data(); +} + +const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +whisper_token whisper_token_eot(struct whisper_context * ctx) { + return ctx->vocab.token_eot; +} + +whisper_token whisper_token_sot(struct whisper_context * ctx) { + return ctx->vocab.token_sot; +} + +whisper_token whisper_token_solm(struct whisper_context * ctx) { + return ctx->vocab.token_solm; +} + +whisper_token whisper_token_prev(struct whisper_context * ctx) { + return ctx->vocab.token_prev; +} + +whisper_token whisper_token_nosp(struct whisper_context * ctx) { + return ctx->vocab.token_nosp; +} + +whisper_token whisper_token_not(struct whisper_context * ctx) { + return ctx->vocab.token_not; +} + +whisper_token whisper_token_beg(struct whisper_context * ctx) { + return ctx->vocab.token_beg; +} + +whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id) { + return whisper_token_sot(ctx) + 1 + lang_id; +} + +whisper_token whisper_token_translate(struct whisper_context * ctx) { + return ctx->vocab.token_translate; +} + +whisper_token whisper_token_transcribe(struct whisper_context * ctx) { + return ctx->vocab.token_transcribe; +} + +void whisper_print_timings(struct whisper_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + WHISPER_LOG_INFO("\n"); + WHISPER_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_batchd = std::max(1, ctx->state->n_batchd); + const int32_t n_prompt = std::max(1, ctx->state->n_prompt); + + WHISPER_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + WHISPER_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + WHISPER_LOG_INFO("%s: sample time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + WHISPER_LOG_INFO("%s: encode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + WHISPER_LOG_INFO("%s: decode time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + WHISPER_LOG_INFO("%s: batchd time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_batchd_us, n_batchd, 1e-3f * ctx->state->t_batchd_us / n_batchd); + WHISPER_LOG_INFO("%s: prompt time = %8.2f ms / %5d runs (%8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_prompt_us, n_prompt, 1e-3f * ctx->state->t_prompt_us / n_prompt); + } + WHISPER_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void whisper_reset_timings(struct whisper_context * ctx) { + ctx->t_start_us = ggml_time_us(); + if (ctx->state != nullptr) { + ctx->state->t_mel_us = 0; + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + ctx->state->t_batchd_us = 0; + ctx->state->t_prompt_us = 0; + ctx->state->n_sample = 0; + 
ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + ctx->state->n_batchd = 0; + ctx->state->n_prompt = 0; + } +} + +static int whisper_has_coreml(void) { +#ifdef WHISPER_USE_COREML + return 1; +#else + return 0; +#endif +} + +static int whisper_has_openvino(void) { +#ifdef WHISPER_USE_OPENVINO + return 1; +#else + return 0; +#endif +} + +const char * whisper_print_system_info(void) { + static std::string s; + + s = ""; + s += "AVX = " + std::to_string(ggml_cpu_has_avx()) + " | "; + s += "AVX2 = " + std::to_string(ggml_cpu_has_avx2()) + " | "; + s += "AVX512 = " + std::to_string(ggml_cpu_has_avx512()) + " | "; + s += "FMA = " + std::to_string(ggml_cpu_has_fma()) + " | "; + s += "NEON = " + std::to_string(ggml_cpu_has_neon()) + " | "; + s += "ARM_FMA = " + std::to_string(ggml_cpu_has_arm_fma()) + " | "; + s += "METAL = " + std::to_string(ggml_cpu_has_metal()) + " | "; + s += "F16C = " + std::to_string(ggml_cpu_has_f16c()) + " | "; + s += "FP16_VA = " + std::to_string(ggml_cpu_has_fp16_va()) + " | "; + s += "WASM_SIMD = " + std::to_string(ggml_cpu_has_wasm_simd()) + " | "; + s += "BLAS = " + std::to_string(ggml_cpu_has_blas()) + " | "; + s += "SSE3 = " + std::to_string(ggml_cpu_has_sse3()) + " | "; + s += "SSSE3 = " + std::to_string(ggml_cpu_has_ssse3()) + " | "; + s += "VSX = " + std::to_string(ggml_cpu_has_vsx()) + " | "; + s += "CUDA = " + std::to_string(ggml_cpu_has_cuda()) + " | "; + s += "COREML = " + std::to_string(whisper_has_coreml()) + " | "; + s += "OPENVINO = " + std::to_string(whisper_has_openvino()) ; + + return s.c_str(); +} + +////////////////////////////////// +// Grammar - ported from llama.cpp +////////////////////////////////// + +// Decodes a UTF-8 string which may end in an incomplete sequence. Adds a terminating 0 for use as +// pointer. If an invalid sequence is encountered, returns `whisper_partial_utf8.n_remain == -1`. 
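// Illustrative sketch (not part of the patch): decode_utf8 below keys the
// sequence length off the high nibble of the lead byte via the 16-entry
// lookup table (0x0-0x7 -> 1 byte, 0x8-0xB -> continuation byte, 0xC/0xD -> 2,
// 0xE -> 3, 0xF -> 4). For the two-byte sequence 0xC3 0xA9 ("é"):
//
//   n_remain = lookup[0xC] - 1 = 1;           // one continuation byte left
//   mask     = (1 << (7 - 1)) - 1 = 0x3F;
//   value    = 0xC3 & 0x3F;                   // == 0x03
//   value    = (value << 6) + (0xA9 & 0x3F);  // == 0xE9 == U+00E9
//
// A lead byte cut off from its continuation bytes is carried across calls in
// whisper_partial_utf8{ value, n_remain }.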
+std::pair<std::vector<uint32_t>, whisper_partial_utf8> decode_utf8( + const char * src, + whisper_partial_utf8 partial_start) { + static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; + const char * pos = src; + std::vector<uint32_t> code_points; + uint32_t value = partial_start.value; + int n_remain = partial_start.n_remain; + + // continue previous decode, if applicable + while (*pos != 0 && n_remain > 0) { + uint8_t next_byte = static_cast<uint8_t>(*pos); + if ((next_byte >> 6) != 2) { + // invalid sequence, abort + code_points.push_back(0); + return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, -1 }); + } + value = (value << 6) + (next_byte & 0x3F); + ++pos; + --n_remain; + } + + if (partial_start.n_remain > 0 && n_remain == 0) { + code_points.push_back(value); + } + + // decode any subsequent utf-8 sequences, which may end in an incomplete one + while (*pos != 0) { + uint8_t first_byte = static_cast<uint8_t>(*pos); + uint8_t highbits = first_byte >> 4; + n_remain = lookup[highbits] - 1; + + if (n_remain < 0) { + // invalid sequence, abort + code_points.clear(); + code_points.push_back(0); + return std::make_pair(std::move(code_points), whisper_partial_utf8{ 0, n_remain }); + } + + uint8_t mask = (1 << (7 - n_remain)) - 1; + value = first_byte & mask; + ++pos; + while (*pos != 0 && n_remain > 0) { + value = (value << 6) + (static_cast<uint8_t>(*pos) & 0x3F); + ++pos; + --n_remain; + } + if (n_remain == 0) { + code_points.push_back(value); + } + } + code_points.push_back(0); + + return std::make_pair(std::move(code_points), whisper_partial_utf8{ value, n_remain }); +} + +// returns true iff pos points to the end of one of the definitions of a rule +static bool whisper_grammar_is_end_of_sequence(const whisper_grammar_element * pos) { + switch (pos->type) { + case WHISPER_GRETYPE_END: return true; // NOLINT + case WHISPER_GRETYPE_ALT: return true; // NOLINT + default: return false; + } +} + +// returns true iff chr satisfies the char range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static std::pair<bool, const whisper_grammar_element *> whisper_grammar_match_char( + const whisper_grammar_element * pos, + const uint32_t chr) { + + bool found = false; + bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; + + WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); // NOLINT + + do { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + found = found || (pos->value <= chr && chr <= pos[1].value); + pos += 2; + } else { + // exact char match, e.g. 
[a] or "a" + found = found || pos->value == chr; + pos += 1; + } + } while (pos->type == WHISPER_GRETYPE_CHAR_ALT); + + return std::make_pair(found == is_positive_char, pos); +} + +// returns true iff some continuation of the given partial UTF-8 sequence could satisfy the char +// range at pos (regular or inverse range) +// asserts that pos is pointing to a char range element +static bool whisper_grammar_match_partial_char( + const whisper_grammar_element * pos, + const whisper_partial_utf8 partial_utf8) { + + bool is_positive_char = pos->type == WHISPER_GRETYPE_CHAR; + WHISPER_ASSERT(is_positive_char || pos->type == WHISPER_GRETYPE_CHAR_NOT); + + uint32_t partial_value = partial_utf8.value; + int n_remain = partial_utf8.n_remain; + + // invalid sequence or 7-bit char split across 2 bytes (overlong) + if (n_remain < 0 || (n_remain == 1 && partial_value < 2)) { + return false; + } + + // range of possible code points this partial UTF-8 sequence could complete to + uint32_t low = partial_value << (n_remain * 6); + uint32_t high = low | ((1 << (n_remain * 6)) - 1); + + if (low == 0) { + if (n_remain == 2) { + low = 1 << 11; + } else if (n_remain == 3) { + low = 1 << 16; + } + } + + do { + if (pos[1].type == WHISPER_GRETYPE_CHAR_RNG_UPPER) { + // inclusive range, e.g. [a-z] + if (pos->value <= high && low <= pos[1].value) { + return is_positive_char; + } + pos += 2; + } else { + // exact char match, e.g. [a] or "a" + if (low <= pos->value && pos->value <= high) { + return is_positive_char; + } + pos += 1; + } + } while (pos->type == WHISPER_GRETYPE_CHAR_ALT); + + return !is_positive_char; +} + + +// transforms a grammar pushdown stack into N possible stacks, all ending +// at a character range (terminal element) +static void whisper_grammar_advance_stack( + const std::vector> & rules, + const std::vector & stack, + std::vector> & new_stacks) { + + if (stack.empty()) { + new_stacks.push_back(stack); + return; + } + + const whisper_grammar_element * pos = stack.back(); + + switch (pos->type) { + case WHISPER_GRETYPE_RULE_REF: { + const size_t rule_id = static_cast(pos->value); + const whisper_grammar_element * subpos = rules[rule_id].data(); + do { + // init new stack without the top (pos) + std::vector new_stack(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(pos + 1)) { + // if this rule ref is followed by another element, add that to stack + new_stack.push_back(pos + 1); + } + if (!whisper_grammar_is_end_of_sequence(subpos)) { + // if alternate is nonempty, add to stack + new_stack.push_back(subpos); + } + whisper_grammar_advance_stack(rules, new_stack, new_stacks); + while (!whisper_grammar_is_end_of_sequence(subpos)) { + // scan to end of alternate def + subpos++; + } + if (subpos->type == WHISPER_GRETYPE_ALT) { + // there's another alternate def of this rule to process + subpos++; + } else { + break; + } + } while (true); + break; + } + case WHISPER_GRETYPE_CHAR: + case WHISPER_GRETYPE_CHAR_NOT: + new_stacks.push_back(stack); + break; + default: + // end of alternate (WHISPER_GRETYPE_END, WHISPER_GRETYPE_ALT) or middle of char range + // (WHISPER_GRETYPE_CHAR_ALT, WHISPER_GRETYPE_CHAR_RNG_UPPER); stack should never be left on + // those + WHISPER_ASSERT(false); + } +} + +// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `whisper_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector> whisper_grammar_accept( + 
+// takes a set of possible pushdown stacks on a grammar, which are required to +// be positioned at a character range (see `whisper_grammar_advance_stack`), and +// produces the N possible stacks if the given char is accepted at those +// positions +static std::vector<std::vector<const whisper_grammar_element *>> whisper_grammar_accept( + const std::vector<std::vector<whisper_grammar_element>> & rules, + const std::vector<std::vector<const whisper_grammar_element *>> & stacks, + const uint32_t chr) { + + std::vector<std::vector<const whisper_grammar_element *>> new_stacks; + + for (const auto & stack : stacks) { + if (stack.empty()) { + continue; + } + + auto match = whisper_grammar_match_char(stack.back(), chr); + if (match.first) { + const whisper_grammar_element * pos = match.second; + + // update top of stack to next element, if any + std::vector<const whisper_grammar_element *> new_stack(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(pos)) { + new_stack.push_back(pos); + } + whisper_grammar_advance_stack(rules, new_stack, new_stacks); + } + } + + return new_stacks; +} + +static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates( + const std::vector<std::vector<whisper_grammar_element>> & rules, + const std::vector<std::vector<const whisper_grammar_element *>> & stacks, + const std::vector<whisper_grammar_candidate> & candidates); + +static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates_for_stack( + const std::vector<std::vector<whisper_grammar_element>> & rules, + const std::vector<const whisper_grammar_element *> & stack, + const std::vector<whisper_grammar_candidate> & candidates) { + + std::vector<whisper_grammar_candidate> rejects; + + if (stack.empty()) { + for (auto tok : candidates) { + if (*tok.code_points != 0 || tok.partial_utf8.n_remain != 0) { + rejects.push_back(tok); + } + } + return rejects; + } + + const whisper_grammar_element * stack_pos = stack.back(); + + std::vector<whisper_grammar_candidate> next_candidates; + for (auto tok : candidates) { + if (*tok.code_points == 0) { + // reached end of full codepoints in token, reject iff it ended in a partial sequence + // that cannot satisfy this position in grammar + if (tok.partial_utf8.n_remain != 0 && !whisper_grammar_match_partial_char(stack_pos, tok.partial_utf8)) { + rejects.push_back(tok); + } + } else if (whisper_grammar_match_char(stack_pos, *tok.code_points).first) { + next_candidates.push_back({ tok.id, tok.code_points + 1, tok.partial_utf8 }); + } else { + rejects.push_back(tok); + } + } + + const auto * stack_pos_after = whisper_grammar_match_char(stack_pos, 0).second; + + // update top of stack to next element, if any + std::vector<const whisper_grammar_element *> stack_after(stack.begin(), stack.end() - 1); + if (!whisper_grammar_is_end_of_sequence(stack_pos_after)) { + stack_after.push_back(stack_pos_after); + } + std::vector<std::vector<const whisper_grammar_element *>> next_stacks; + whisper_grammar_advance_stack(rules, stack_after, next_stacks); + + auto next_rejects = whisper_grammar_reject_candidates(rules, next_stacks, next_candidates); + for (auto tok : next_rejects) { + rejects.push_back({ tok.id, tok.code_points - 1, tok.partial_utf8 }); + } + + return rejects; +} + +static std::vector<whisper_grammar_candidate> whisper_grammar_reject_candidates( + const std::vector<std::vector<whisper_grammar_element>> & rules, + const std::vector<std::vector<const whisper_grammar_element *>> & stacks, + const std::vector<whisper_grammar_candidate> & candidates) { + if (candidates.empty() || stacks.empty()) { + return std::vector<whisper_grammar_candidate>(); + } + + auto rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks.front(), candidates); + + for (size_t i = 1, size = stacks.size(); i < size; ++i) { + rejects = whisper_grammar_reject_candidates_for_stack(rules, stacks[i], rejects); + } + return rejects; +} + +static struct whisper_grammar whisper_grammar_init( + const whisper_grammar_element ** rules, + size_t n_rules, + size_t i_start_rule) { + const whisper_grammar_element * pos; + + // copy rule definitions into vectors + std::vector<std::vector<whisper_grammar_element>> vec_rules(n_rules); + for (size_t i = 0; i < n_rules; i++) { + for (pos = rules[i]; pos->type != WHISPER_GRETYPE_END; pos++) { + vec_rules[i].push_back(*pos); + } + vec_rules[i].push_back({WHISPER_GRETYPE_END, 0}); + } + + // loop over alternates of start rule to build initial stacks + std::vector<std::vector<const whisper_grammar_element *>> stacks; + pos = rules[i_start_rule]; + do { + std::vector<const whisper_grammar_element *> stack; + if (!whisper_grammar_is_end_of_sequence(pos)) { + // if 
alternate is nonempty, add to stack + stack.push_back(pos); + } + whisper_grammar_advance_stack(vec_rules, stack, stacks); + while (!whisper_grammar_is_end_of_sequence(pos)) { + // scan to end of alternate def + pos++; + } + if (pos->type == WHISPER_GRETYPE_ALT) { + // there's another alternate def of this rule to process + pos++; + } else { + break; + } + } while (true); + + return { std::move(vec_rules), std::move(stacks), {} }; +} + +static void whisper_suppress_invalid_grammar( + whisper_context & ctx, + const whisper_full_params & params, + std::vector<float> & logits, + const whisper_grammar & grammar) { + + if (grammar.rules.empty() || grammar.stacks.empty()) { + return; + } + + //bool allow_eot = false; + //for (const auto & stack : grammar.stacks) { + // if (stack.empty()) { + // allow_eot = true; + // break; + // } + //} + + const whisper_token eot = whisper_token_eot(&ctx); + + std::vector<std::pair<std::vector<uint32_t>, whisper_partial_utf8>> candidates_decoded; + std::vector<whisper_grammar_candidate> candidates_grammar; + + for (whisper_token id = 0; id < eot; ++id) { + const std::string & text = ctx.vocab.id_to_token[id]; + if (!text.empty()) { + candidates_decoded.push_back(decode_utf8(text.c_str(), grammar.partial_utf8)); + candidates_grammar.push_back({ id, candidates_decoded.back().first.data(), candidates_decoded.back().second }); + } + } + + const auto rejects = whisper_grammar_reject_candidates(grammar.rules, grammar.stacks, candidates_grammar); + + for (const auto & reject : rejects) { + logits[reject.id] -= params.grammar_penalty; + } + + // when the grammar allows a continuation, we penalize the end-of-text token + //if (!allow_eot) { + // logits[eot] -= params.grammar_penalty; + //} + //fprintf(stderr, "Allowed: (%zu tokens)\n", size - rejects.size()); +} + +static void whisper_grammar_accept_token(whisper_context & ctx, whisper_grammar & grammar, whisper_token token) { + if (grammar.rules.empty() || grammar.stacks.empty()) { + return; + } + + //fprintf(stderr, "Accept: '%s'\n", ctx.vocab.id_to_token[token].c_str()); + + const std::string & text = ctx.vocab.id_to_token[token]; + + if (text.rfind("[_", 0) == 0) { + // fprintf(stderr, " (skipped)\n"); + return; + } + // fprintf(stderr, "\n"); + + // Note terminating 0 in decoded string + const auto decoded = decode_utf8(text.c_str(), grammar.partial_utf8); + const auto & code_points = decoded.first; + for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { + grammar.stacks = whisper_grammar_accept(grammar.rules, grammar.stacks, *it); + } + grammar.partial_utf8 = decoded.second; +} + +////////////// +// END grammar +////////////// + +//////////////////////////////////////////////////////////////////////////// + +struct whisper_context_params * whisper_context_default_params_by_ref() { + struct whisper_context_params params = whisper_context_default_params(); + + struct whisper_context_params* result = new whisper_context_params(); + *result = params; + return result; +} + +struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy) { + struct whisper_full_params params = whisper_full_default_params(strategy); + + struct whisper_full_params* result = new whisper_full_params(); + *result = params; + return result; +}
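// Illustrative sketch (not part of the patch): typical use of the defaults
// defined below, assuming the usual whisper_full() entry point and a caller
// that already holds 16 kHz mono float PCM (`ctx` and `pcm` are hypothetical):
//
//   whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
//   wparams.language       = "en";
//   wparams.print_progress = false;
//   if (whisper_full(ctx, wparams, pcm.data(), (int) pcm.size()) != 0) {
//       fprintf(stderr, "transcription failed\n");
//   }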
+struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { + struct whisper_full_params result = { + /*.strategy =*/ strategy, + + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.n_max_text_ctx =*/ 16384, + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + + /*.translate =*/ false, + /*.no_context =*/ true, + /*.no_timestamps =*/ false, + /*.single_segment =*/ false, + /*.print_special =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, + + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.split_on_word =*/ false, + /*.max_tokens =*/ 0, + + /*.speed_up =*/ false, + /*.debug_mode =*/ false, + /*.audio_ctx =*/ 0, + + /*.tdrz_enable =*/ false, + + /* suppress_regex =*/ nullptr, + + /*.initial_prompt =*/ nullptr, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + + /*.language =*/ "en", + /*.detect_language =*/ false, + + /*.suppress_blank =*/ true, + /*.suppress_non_speech_tokens =*/ false, + + /*.temperature =*/ 0.0f, + /*.max_initial_ts =*/ 1.0f, + /*.length_penalty =*/ -1.0f, + + /*.temperature_inc =*/ 0.2f, + /*.entropy_thold =*/ 2.4f, + /*.logprob_thold =*/ -1.0f, + /*.no_speech_thold =*/ 0.6f, + + /*.greedy =*/ { + /*.best_of =*/ -1, + }, + + /*.beam_search =*/ { + /*.beam_size =*/ -1, + + /*.patience =*/ -1.0f, + }, + + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, + + /*.logits_filter_callback =*/ nullptr, + /*.logits_filter_callback_user_data =*/ nullptr, + + /*.grammar_rules =*/ nullptr, + /*.n_grammar_rules =*/ 0, + /*.i_start_rule =*/ 0, + /*.grammar_penalty =*/ 100.0f, + }; + + switch (strategy) { + case WHISPER_SAMPLING_GREEDY: + { + result.greedy = { + /*.best_of =*/ 5, + }; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + result.beam_search = { + /*.beam_size =*/ 5, + + /*.patience =*/ -1.0f, + }; + } break; + } + + return result; +} + +// forward declarations +static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context & ctx, + struct whisper_state & state, + int i_segment, + float thold_pt, + float thold_ptsum); + +static inline bool should_split_on_word(const char * txt, bool split_on_word) { + if (!split_on_word) return true; + + return txt[0] == ' '; +} + +static void whisper_exp_compute_token_level_timestamps_dtw( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + int i_segment, + size_t n_segments, + int seek, + int n_frames, + int medfilt_width, + int n_threads); + +// wrap the last segment to max_len characters +// returns the number of new segments +static int whisper_wrap_segment(struct whisper_context & ctx, struct whisper_state & state, int max_len, bool split_on_word) { + auto segment = state.result_all.back(); + + int res = 1; + int acc = 0; + + std::string text; + + for (int i = 0; i < (int) segment.tokens.size(); i++) { + const auto & token = segment.tokens[i]; + if (token.id >= whisper_token_eot(&ctx)) { + continue; + } + + const auto txt = whisper_token_to_str(&ctx, token.id); + const int cur = strlen(txt); + + if (acc + cur > max_len && i > 0 && should_split_on_word(txt, split_on_word)) { + state.result_all.back().text = std::move(text); + state.result_all.back().t1 = token.t0; + state.result_all.back().tokens.resize(i); + state.result_all.back().speaker_turn_next = false; + 
state.result_all.push_back({}); + state.result_all.back().t0 = token.t0; + state.result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + state.result_all.back().tokens.insert( + state.result_all.back().tokens.end(), + segment.tokens.begin() + i, + segment.tokens.end()); + + state.result_all.back().speaker_turn_next = segment.speaker_turn_next; + + acc = 0; + text = ""; + + segment = state.result_all.back(); + i = -1; + + res++; + } else { + acc += cur; + text += txt; + } + } + + state.result_all.back().text = std::move(text); + + return res; +} + +static const std::vector<std::string> non_speech_tokens = { + "\"", "#", "(", ")", "*", "+", "/", ":", ";", "<", "=", ">", "@", "[", "\\", "]", "^", + "_", "`", "{", "|", "}", "~", "「", "」", "『", "』", "<<", ">>", "<<<", ">>>", "--", + "---", "-(", "-[", "('", "(\"", "((", "))", "(((", ")))", "[[", "]]", "{{", "}}", "♪♪", + "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" +}; + +// process the logits for the selected decoder +// - applies logit filters +// - computes logprobs and probs +// TODO: optimize +static void whisper_process_logits( + struct whisper_context & ctx, + struct whisper_state & state, + struct whisper_decoder & decoder, + const struct whisper_full_params params, + float temperature) { + const auto & vocab = ctx.vocab; + const auto & tokens_cur = decoder.sequence.tokens; + + const bool is_initial = tokens_cur.size() == 0; + const int n_logits = vocab.id_to_token.size(); + + WHISPER_ASSERT(n_logits == ctx.vocab.n_vocab); + + // extract the logits for the last token + // we will be mutating, and therefore we don't want to use the ctx.logits buffer directly + auto & probs = decoder.probs; + auto & logits = decoder.logits; + auto & logprobs = decoder.logprobs; + { + logits.resize(n_logits); + memcpy(logits.data(), state.logits.data() + decoder.i_batch*n_logits, n_logits*sizeof(float)); + + if (temperature > 0.0f) { + for (int i = 0; i < n_logits; i++) { + logits[i] /= temperature; + } + } + + // will be populated a bit later + probs.resize(n_logits); + logprobs.resize(n_logits); + } + + // apply logit filters here + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L480-L493 + { + // suppress blank + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L388-L390 + if (params.suppress_blank) { + if (is_initial) { + logits[vocab.token_eot] = -INFINITY; + logits[vocab.token_to_id.at(" ")] = -INFINITY; + } + } + + // suppress <|notimestamps|> token + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L410-L412 + logits[vocab.token_not] = -INFINITY; + if (params.no_timestamps) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } + + // suppress sot and nosp tokens + logits[vocab.token_sot] = -INFINITY; + logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now + + // [TDRZ] when tinydiarize is disabled, suppress solm token + if (params.tdrz_enable == false) { + logits[vocab.token_solm] = -INFINITY; + } + + // suppress task tokens + logits[vocab.token_translate] = -INFINITY; + logits[vocab.token_transcribe] = -INFINITY; + logits[vocab.token_prev] = -INFINITY; + + // suppress lang tokens + for (size_t i = 0; i < g_lang.size(); ++i) { + logits[whisper_token_lang(&ctx, i)] = -INFINITY; + } + + // suppress prev token + logits[vocab.token_prev] = -INFINITY; + + if (params.logits_filter_callback) { + 
params.logits_filter_callback(&ctx, &state, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data); + } + + // suppress any tokens matching a regular expression + // ref: https://github.com/openai/whisper/discussions/1041 + if (params.suppress_regex != nullptr) { + std::regex re(params.suppress_regex); + for (std::pair<std::string, whisper_token> token_id : vocab.token_to_id) { + if (std::regex_match(token_id.first, re)) { + logits[token_id.second] = -INFINITY; + } + } + } + + // suppress non-speech tokens + // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + if (params.suppress_non_speech_tokens) { + for (const std::string & token : non_speech_tokens) { + const std::string suppress_tokens[] = {token, " " + token}; + for (const std::string & suppress_token : suppress_tokens) { + if (vocab.token_to_id.find(suppress_token) != vocab.token_to_id.end()) { + logits[vocab.token_to_id.at(suppress_token)] = -INFINITY; + } + } + } + + // allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word + if (vocab.token_to_id.find(" -") != vocab.token_to_id.end()) { + logits[vocab.token_to_id.at(" -")] = -INFINITY; + } + if (vocab.token_to_id.find(" '") != vocab.token_to_id.end()) { + logits[vocab.token_to_id.at(" '")] = -INFINITY; + } + } + + // timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + // https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L414-L424 + { + const bool last_was_timestamp = tokens_cur.size() > 0 && tokens_cur.back().id >= vocab.token_beg; + const bool penultimate_was_timestamp = tokens_cur.size() < 2 || tokens_cur[tokens_cur.size() - 2].id >= vocab.token_beg; + + //WHISPER_LOG_INFO("last_was_timestamp=%d penultimate_was_timestamp=%d\n", last_was_timestamp, penultimate_was_timestamp); + + if (last_was_timestamp) { + if (penultimate_was_timestamp) { + for (int i = vocab.token_beg; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } else { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } + } + } + } + + // the initial timestamp cannot be larger than max_initial_ts + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L426-L429 + if (is_initial && params.max_initial_ts > 0.0f) { + const float precision = float(WHISPER_CHUNK_SIZE)/ctx.model.hparams.n_audio_ctx; + const int tid0 = std::round(params.max_initial_ts/precision); + + for (int i = vocab.token_beg + tid0 + 1; i < n_logits; ++i) { + logits[i] = -INFINITY; + } + } + + // condition timestamp tokens to be increasing + // ref: https://github.com/openai/whisper/pull/831#issuecomment-1385910556 + if (decoder.has_ts) { + const int tid0 = decoder.seek_delta/2; + + for (int i = vocab.token_beg; i < vocab.token_beg + tid0; ++i) { + logits[i] = -INFINITY; + } + } + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + }
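// Illustrative sketch (not part of the patch): the block above is a standard
// log-softmax in the max-shifted form
//
//   logprob[i] = logit[i] - (log(sum_j exp(logit[j] - max)) + max)
//
// e.g. for logits {2, 1}: max = 2, logsumexp = log(1 + e^-1) + 2 ~= 2.3133,
// so logprob = {-0.3133, -1.3133} and exp(logprob) = {0.731, 0.269},
// matching a direct softmax while avoiding overflow in expf().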
// if sum of probability over timestamps is above any other token, sample timestamp + // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 + { + // logsumexp over timestamps + float timestamp_logprob = -INFINITY; + { + float logsumexp = 0.0f; + const float logprob_max = *std::max_element(logprobs.begin() + vocab.token_beg, logprobs.end()); + for (int i = vocab.token_beg; i < n_logits; ++i) { + if (logprobs[i] > -INFINITY) { + logsumexp += expf(logprobs[i] - logprob_max); + } + } + if (logsumexp > 0.0f) { + timestamp_logprob = logf(logsumexp) + logprob_max; + } + } + + const float max_text_token_logprob = *std::max_element(logprobs.begin(), logprobs.begin() + vocab.token_beg); + + //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob); + + if (timestamp_logprob > max_text_token_logprob) { + for (int i = 0; i < vocab.token_beg; ++i) { + logits[i] = -INFINITY; + logprobs[i] = -INFINITY; + } + } else { + if (params.n_grammar_rules > 0) { + whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); + + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } + } + } + } + } + } + + // compute probs + { + for (int i = 0; i < n_logits; ++i) { + if (logits[i] == -INFINITY) { + probs[i] = 0.0f; + } else { + probs[i] = expf(logprobs[i]); + } + } + } + +#if 0 + // print first 100 logits - token string : logit + //for (int i = 0; i < 10; i++) { + // const auto token = vocab.id_to_token.at(i); + // const auto prob = probs[i]; + // const auto logit = logits[i]; + // const auto logprob = logprobs[i]; + // printf("%16s : prob=%9.5f logit=%9.5f logprob=%9.5f\n", token.c_str(), prob, logit, logprob); + //} + + // print sorted + { + std::vector<std::pair<float, int>> pairs; + + for (int i = 0; i < n_logits; ++i) { + pairs.push_back(std::make_pair(probs[i], i)); + } + + std::sort(pairs.begin(), pairs.end(), [](const std::pair<float, int>& a, const std::pair<float, int>& b) { + return a.first > b.first; + }); + + for (int i = 0; i < 10; i++) { + const auto token = vocab.id_to_token.at(pairs[i].second); + const auto prob = pairs[i].first; + const auto logit = logits[pairs[i].second]; + const auto logprob = logprobs[pairs[i].second]; + printf("%16s : id=%6d prob=%9.5f logit=%9.5f logprob=%9.5f '%s'\n", token.c_str(), pairs[i].second, prob, logit, logprob, token.c_str()); + } + + printf("----------------\n"); + } + + // "And", "and", " And", " and" + //printf("logits[\"and\"] = %f\n", logits[vocab.token_to_id.at("and")]); + //printf("logits[\"And\"] = %f\n", logits[vocab.token_to_id.at("And")]); + //printf("logits[\" and\"] = %f\n", logits[vocab.token_to_id.at(" and")]); + //printf("logits[\" And\"] = %f\n", logits[vocab.token_to_id.at(" And")]); + //printf("logits[\" so\"] = %f\n", logits[vocab.token_to_id.at(" so")]); + + //printf("logprobs[\"and\"] = %f\n", logprobs[vocab.token_to_id.at("and")]); + //printf("logprobs[\"And\"] = %f\n", logprobs[vocab.token_to_id.at("And")]); + //printf("logprobs[\" and\"] = %f\n", logprobs[vocab.token_to_id.at(" and")]); + //printf("logprobs[\" And\"] = %f\n", logprobs[vocab.token_to_id.at(" And")]); + //printf("logprobs[\" so\"] = %f\n", logprobs[vocab.token_to_id.at(" 
so")]); + + //printf("probs[\"and\"] = %f\n", probs[vocab.token_to_id.at("and")]); + //printf("probs[\"And\"] = %f\n", probs[vocab.token_to_id.at("And")]); + //printf("probs[\" and\"] = %f\n", probs[vocab.token_to_id.at(" and")]); + //printf("probs[\" And\"] = %f\n", probs[vocab.token_to_id.at(" And")]); + //printf("probs[\" so\"] = %f\n", probs[vocab.token_to_id.at(" so")]); +#endif +} + +static bool whisper_sequence_tokens_equal(const whisper_sequence & a, const whisper_sequence & b) { + if (a.tokens.size() != b.tokens.size()) { + return false; + } + // sequences are more likely to diverge at the end + for (int i = a.tokens.size() - 1; i >= 0; i--) { + if (a.tokens[i].id != b.tokens[i].id) { + return false; + } + } + return true; +} + +static whisper_token_data whisper_sample_token( + whisper_context & ctx, + const whisper_decoder & decoder, + bool best) { + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, 0.0f, -1, -1, -1, 0.0f, + }; + + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + result.tid = i; + } + } + + result.pt = max_ts/(sum_ts + 1e-10); + result.ptsum = sum_ts; + } + + if (best) { + for (int i = 0; i < n_logits; ++i) { + if (result.p < probs[i]) { + result.id = i; + result.p = probs[i]; + result.plog = logprobs[i]; + } + } + } else { + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + result.id = dist(decoder.rng); + result.p = probs[result.id]; + result.plog = logprobs[result.id]; + } + + if (result.id >= vocab.token_beg) { + result.tid = result.id; + result.pt = result.p; + } + + return result; +} + +static std::vector whisper_sample_token_topk( + whisper_context & ctx, + whisper_decoder & decoder, + int k) { + const auto & vocab = ctx.vocab; + + const auto & probs = decoder.probs; + const auto & logits = decoder.logits; + const auto & logprobs = decoder.logprobs; + + const int n_logits = vocab.n_vocab; + + auto & logits_id = decoder.logits_id; + + logits_id.resize(n_logits); + for (int i = 0; i < n_logits; ++i) { + logits_id[i].first = logits[i]; + logits_id[i].second = i; + } + + { + using pair_type = std::remove_reference::type::value_type; + std::partial_sort( + logits_id.begin(), + logits_id.begin() + k, logits_id.end(), + [](const pair_type & a, const pair_type & b) { + return a.first > b.first; + }); + } + + std::vector result; + result.reserve(k); + + whisper_token tid = vocab.token_beg; + + float pt = 0.0; + float ptsum = 0.0; + + { + double sum_ts = 0.0; + double max_ts = 0.0; + + for (int i = vocab.token_beg; i < n_logits; i++) { + if (probs[i] == -INFINITY) { + continue; + } + + sum_ts += probs[i]; + if (max_ts < probs[i]) { + max_ts = probs[i]; + tid = i; + } + } + + pt = max_ts/(sum_ts + 1e-10); + ptsum = sum_ts; + } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + + for (int i = 0; i < k; ++i) { + const auto id = dist(decoder.rng); + //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); + + result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, -1, 0.0f, }); + + if (result[i].id >= vocab.token_beg) { + result[i].tid = result[i].id; + result[i].pt = result[i].p; + } + } + + return result; +} + +// ref: 
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L178-L192 +static void whisper_sequence_score( + const struct whisper_full_params & params, + whisper_sequence & sequence) { + if (sequence.result_len == 0) { + return; + } + + double result = 0.0f; + + for (int i = 0; i < sequence.result_len; ++i) { + result += sequence.tokens[i].plog; + } + + sequence.sum_logprobs = result; + sequence.avg_logprobs = result/sequence.result_len; + + double penalty = sequence.result_len; + + if (params.length_penalty > 0.0f) { + penalty = pow((5.0 + penalty)/6.0, params.length_penalty); + } + + sequence.score = result/penalty; + + // compute the entropy of the sequence of the last 32 tokens + { + const int n = 32; + + int cnt = 0; + double entropy = 0.0f; + + std::map<whisper_token, int> token_counts; + for (int i = std::max(0, sequence.result_len - n); i < sequence.result_len; ++i) { + token_counts[sequence.tokens[i].id]++; + cnt++; + } + + for (const auto & kv : token_counts) { + const auto p = kv.second/(double)cnt; + entropy -= p*log(p); + + //WHISPER_LOG_DEBUG("entropy: %d %f %f, count %d\n", kv.first, p, log(p), kv.second); + } + + sequence.entropy = entropy; + } +} + +int whisper_full_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + const float * samples, + int n_samples) { + // clear old results + auto & result_all = state->result_all; + + result_all.clear(); + + if (n_samples > 0) { + // compute log mel spectrogram + if (params.speed_up) { + // TODO: Replace PV with more advanced algorithm + WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -1; + } else { + if (whisper_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + WHISPER_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + } + + // auto-detect language if not specified + if (params.language == nullptr || strlen(params.language) == 0 || strcmp(params.language, "auto") == 0 || params.detect_language) { + std::vector<float> probs(whisper_lang_max_id() + 1, 0.0f); + + const auto lang_id = whisper_lang_auto_detect_with_state(ctx, state, 0, params.n_threads, probs.data()); + if (lang_id < 0) { + WHISPER_LOG_ERROR("%s: failed to auto-detect language\n", __func__); + return -3; + } + state->lang_id = lang_id; + params.language = whisper_lang_str(lang_id); + + WHISPER_LOG_INFO("%s: auto-detected language: %s (p = %f)\n", __func__, params.language, probs[whisper_lang_id(params.language)]); + if (params.detect_language) { + return 0; + } + } + + if (params.token_timestamps) { + state->t_beg = 0; + state->t_last = 0; + state->tid_last = 0; + if (n_samples > 0) { + state->energy = get_signal_energy(samples, n_samples, 32); + } + } + + const int seek_start = params.offset_ms/10; + const int seek_end = params.duration_ms == 0 ? whisper_n_len_from_state(state) : seek_start + params.duration_ms/10; + + // if length of spectrogram is less than 1.0s (100 frames), then return + // basically don't process anything that is less than 1.0s + // see issue #39: https://github.com/ggerganov/whisper.cpp/issues/39 + if (seek_end < seek_start + (params.speed_up ? 50 : 100)) { + WHISPER_LOG_WARN("%s: input is too short - %d ms < 1000 ms. 
consider padding the input audio with silence\n", __func__, (seek_end - seek_start)*10); + return 0; + } + + // a set of temperatures to use + // [ t0, t0 + delta, t0 + 2*delta, ..., < 1.0f + 1e-6f ] + std::vector<float> temperatures; + if (params.temperature_inc > 0.0f) { + for (float t = params.temperature; t < 1.0f + 1e-6f; t += params.temperature_inc) { + temperatures.push_back(t); + } + } else { + temperatures.push_back(params.temperature); + } + + // initialize the decoders + int n_decoders = 1; + + switch (params.strategy) { + case WHISPER_SAMPLING_GREEDY: + { + n_decoders = params.greedy.best_of; + } break; + case WHISPER_SAMPLING_BEAM_SEARCH: + { + n_decoders = std::max(params.greedy.best_of, params.beam_search.beam_size); + } break; + }; + + n_decoders = std::max(1, n_decoders); + + if (n_decoders > WHISPER_MAX_DECODERS) { + WHISPER_LOG_ERROR("%s: too many decoders requested (%d), max = %d\n", __func__, n_decoders, WHISPER_MAX_DECODERS); + return -4; + } + + // TAGS: WHISPER_DECODER_INIT + for (int j = 1; j < n_decoders; j++) { + auto & decoder = state->decoders[j]; + + decoder.sequence.tokens.reserve(state->decoders[0].sequence.tokens.capacity()); + + decoder.probs.resize (ctx->vocab.n_vocab); + decoder.logits.resize (ctx->vocab.n_vocab); + decoder.logprobs.resize(ctx->vocab.n_vocab); + decoder.logits_id.reserve(ctx->model.hparams.n_vocab); + + decoder.rng = std::mt19937(0); + } + + // the accumulated text context so far + auto & prompt_past = state->prompt_past; + if (params.no_context) { + prompt_past.clear(); + } + + // prepare prompt + { + std::vector<whisper_token> prompt_tokens; + + // initial prompt + if (!params.prompt_tokens && params.initial_prompt) { + prompt_tokens.resize(1024); + int n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()); + if (n_needed < 0) { + prompt_tokens.resize(-n_needed); + n_needed = whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()); + } + prompt_tokens.resize(n_needed); + params.prompt_tokens = prompt_tokens.data(); + params.prompt_n_tokens = prompt_tokens.size(); + } + + // prepend the prompt tokens to the prompt_past + if (params.prompt_tokens && params.prompt_n_tokens > 0) { + // parse tokens from the pointer + for (int i = 0; i < params.prompt_n_tokens; i++) { + prompt_past.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); + } + } + + // overwrite audio_ctx, max allowed is hparams.n_audio_ctx + if (params.audio_ctx > whisper_n_audio_ctx(ctx)) { + WHISPER_LOG_ERROR("%s: audio_ctx is larger than the maximum allowed (%d > %d)\n", __func__, params.audio_ctx, whisper_n_audio_ctx(ctx)); + return -5; + } + state->exp_n_audio_ctx = params.audio_ctx; + + // these tokens determine the task that will be performed + std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx), }; + + if (whisper_is_multilingual(ctx)) { + const int lang_id = whisper_lang_id(params.language); + state->lang_id = lang_id; + prompt_init.push_back(whisper_token_lang(ctx, lang_id)); + if (params.translate) { + prompt_init.push_back(whisper_token_translate(ctx)); + } else { + prompt_init.push_back(whisper_token_transcribe(ctx)); + } + }
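// Illustrative sketch (not part of the patch): the task prefix assembled
// above follows the OpenAI Whisper prompt layout. For multilingual French
// translation, for instance, the decoder is conditioned on
//
//   [_SOT_] [_LANG_fr] [_TRANSLATE_]
//
// where [_LANG_fr] is whisper_token_sot(ctx) + 1 + whisper_lang_id("fr"),
// and <|notimestamps|> (whisper_token_not) is appended just below when
// no_timestamps is in effect. Any accumulated past text is later prepended
// behind a [_PREV_] token when the full prompt is built.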
no_timestamps\n", __func__); + params.no_timestamps = true; + } + } + + if (params.no_timestamps) { + prompt_init.push_back(whisper_token_not(ctx)); + } + + int seek = seek_start; + + std::vector prompt; + prompt.reserve(whisper_n_text_ctx(ctx)); + + struct beam_candidate { + int decoder_idx; + int seek_delta; + + bool has_ts; + + whisper_sequence sequence; + whisper_grammar grammar; + }; + + std::vector> bc_per_dec(n_decoders); + std::vector beam_candidates; + + // main loop + while (true) { + if (params.progress_callback) { + const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); + + params.progress_callback( + ctx, state, progress_cur, params.progress_callback_user_data); + } + + // if only 1 second left, then stop + if (seek + 100 >= seek_end) { + break; + } + + if (params.encoder_begin_callback) { + if (params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data) == false) { + WHISPER_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); + break; + } + } + + // encode audio features starting at offset seek + if (!whisper_encode_internal(*ctx, *state, seek, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + WHISPER_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + // if there is a very short audio segment left to process, we remove any past prompt since it tends + // to confuse the decoder and often make it repeat or hallucinate stuff + if (seek > seek_start && seek + 500 >= seek_end) { + prompt_past.clear(); + } + + int best_decoder_id = 0; + + for (int it = 0; it < (int) temperatures.size(); ++it) { + const float t_cur = temperatures[it]; + + int n_decoders_cur = 1; + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + if (t_cur > 0.0f) { + n_decoders_cur = params.greedy.best_of; + } else { + n_decoders_cur = params.beam_search.beam_size; + } + } break; + }; + + n_decoders_cur = std::max(1, n_decoders_cur); + + WHISPER_LOG_DEBUG("\n%s: strategy = %d, decoding with %d decoders, temperature = %.2f\n", __func__, params.strategy, n_decoders_cur, t_cur); + + // TAGS: WHISPER_DECODER_INIT + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + decoder.sequence.tokens.clear(); + decoder.sequence.result_len = 0; + decoder.sequence.sum_logprobs_all = 0.0; + decoder.sequence.sum_logprobs = -INFINITY; + decoder.sequence.avg_logprobs = -INFINITY; + decoder.sequence.entropy = 0.0; + decoder.sequence.score = -INFINITY; + + decoder.seek_delta = 100*WHISPER_CHUNK_SIZE; + + decoder.failed = false; + decoder.completed = false; + decoder.has_ts = false; + + if (params.grammar_rules != nullptr) { + decoder.grammar = whisper_grammar_init(params.grammar_rules, params.n_grammar_rules, params.i_start_rule); + } else { + decoder.grammar = {}; + } + } + + // init prompt and kv cache for the current iteration + // TODO: do not recompute the prompt if it is the same as previous time + { + prompt.clear(); + + // if we have already generated some text, use it as a prompt to condition the next generation + if (!prompt_past.empty() && t_cur < 0.5f && params.n_max_text_ctx > 0) { + int n_take = std::min(std::min(params.n_max_text_ctx, whisper_n_text_ctx(ctx)/2), int(prompt_past.size())); + + prompt = { whisper_token_prev(ctx) }; + prompt.insert(prompt.begin() + 1, prompt_past.end() - 
n_take, prompt_past.end()); + } + + // init new transcription with sot, language (opt) and task tokens + prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + + // print the prompt + WHISPER_LOG_DEBUG("\n\n"); + for (int i = 0; i < (int) prompt.size(); i++) { + WHISPER_LOG_DEBUG("%s: prompt[%d] = %s\n", __func__, i, ctx->vocab.id_to_token.at(prompt[i]).c_str()); + } + WHISPER_LOG_DEBUG("\n\n"); + + whisper_kv_cache_clear(state->kv_self); + + whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0); + + if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) { + WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + { + const int64_t t_start_sample_us = ggml_time_us(); + + state->decoders[0].i_batch = prompt.size() - 1; + + whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur); + + for (int j = 1; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + whisper_kv_cache_seq_cp(state->kv_self, 0, j, -1, -1); + + memcpy(decoder.probs.data(), state->decoders[0].probs.data(), decoder.probs.size()*sizeof(decoder.probs[0])); + memcpy(decoder.logits.data(), state->decoders[0].logits.data(), decoder.logits.size()*sizeof(decoder.logits[0])); + memcpy(decoder.logprobs.data(), state->decoders[0].logprobs.data(), decoder.logprobs.size()*sizeof(decoder.logprobs[0])); + } + + state->t_sample_us += ggml_time_us() - t_start_sample_us; + } + } + + for (int i = 0, n_max = whisper_n_text_ctx(ctx)/2 - 4; i < n_max; ++i) { + const int64_t t_start_sample_us = ggml_time_us(); + + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + for (auto & bc : bc_per_dec) { + bc.clear(); + } + } + + // sampling + // TODO: avoid memory allocations, optimize, avoid threads? 
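// Illustrative sketch (not part of the patch): the block below spreads the
// per-decoder work across threads with a shared atomic counter rather than a
// fixed partition. Each worker claims the next unprocessed index:
//
//   #include <atomic>
//
//   std::atomic<int> next(0);
//   auto worker = [&]() {
//       for (int j = next.fetch_add(1); j < n_jobs; j = next.fetch_add(1)) {
//           process(j); // hypothetical per-decoder work
//       }
//   };
//
// fetch_add returns the pre-increment value, so each index is claimed exactly
// once even when several workers race on the counter.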
+ { + std::atomic<int> j_cur(0); + + auto process = [&]() { + while (true) { + const int j = j_cur.fetch_add(1); + + if (j >= n_decoders_cur) { + break; + } + + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size); + + for (const auto & token : tokens_new) { + bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, }); + bc_per_dec[j].back().sequence.tokens.push_back(token); + bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog; + } + } break; + }; + } + }; + + const int n_threads = std::min(params.n_threads, n_decoders_cur); + + if (n_threads == 1) { + process(); + } else { + std::vector<std::thread> threads(n_threads - 1); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t] = std::thread(process); + } + + process(); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t].join(); + } + } + } + + beam_candidates.clear(); + for (const auto & bc : bc_per_dec) { + beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end()); + + if (!bc.empty()) { + state->n_sample += 1; + } + } + + // for beam-search, choose the top candidates and update the KV caches + if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { + std::sort( + beam_candidates.begin(), + beam_candidates.end(), + [](const beam_candidate & a, const beam_candidate & b) { + if (a.sequence.sum_logprobs_all != b.sequence.sum_logprobs_all) { + return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all; + } + return a.decoder_idx < b.decoder_idx; + }); + + uint32_t cur_c = 0; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + if (cur_c >= beam_candidates.size()) { + cur_c = 0; + } + + auto & cur = beam_candidates[cur_c++]; + + while (beam_candidates.size() > cur_c && whisper_sequence_tokens_equal(beam_candidates[cur_c].sequence, cur.sequence) && i > 0) { + ++cur_c; + } + + decoder.seek_delta = cur.seek_delta; + decoder.has_ts = cur.has_ts; + decoder.sequence = cur.sequence; + decoder.grammar = cur.grammar; + + whisper_kv_cache_seq_cp(state->kv_self, cur.decoder_idx, WHISPER_MAX_DECODERS + j, -1, -1); + + WHISPER_LOG_DEBUG("%s: beam search: decoder %d: from decoder %d: token = %10s, plog = %8.5f, sum_logprobs = %8.5f\n", + __func__, j, cur.decoder_idx, ctx->vocab.id_to_token.at(decoder.sequence.tokens.back().id).c_str(), decoder.sequence.tokens.back().plog, decoder.sequence.sum_logprobs_all); + } + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + whisper_kv_cache_seq_rm(state->kv_self, j, -1, -1); + whisper_kv_cache_seq_cp(state->kv_self, WHISPER_MAX_DECODERS + j, j, -1, -1); + whisper_kv_cache_seq_rm(state->kv_self, WHISPER_MAX_DECODERS + j, -1, -1); + } + } + + // update the decoder state + // - check if the sequence is completed + // - check if the sequence is failed + 
// - update sliding window based on timestamp tokens + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + auto & has_ts = decoder.has_ts; + auto & failed = decoder.failed; + auto & completed = decoder.completed; + auto & seek_delta = decoder.seek_delta; + auto & result_len = decoder.sequence.result_len; + + { + const auto & token = decoder.sequence.tokens.back(); + + // timestamp token - update sliding window + if (token.id > whisper_token_beg(ctx)) { + const int seek_delta_new = 2*(token.id - whisper_token_beg(ctx)); + + // do not allow to go back in time + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + WHISPER_LOG_DEBUG("%s: decoder %d: failed due to seek_delta (%d > %d)\n", __func__, j, seek_delta, seek_delta_new); + failed = true; // TODO: maybe this is not a failure ? + continue; + } + + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + + whisper_grammar_accept_token(*ctx, decoder.grammar, token.id); + +#ifdef WHISPER_DEBUG + { + const auto tt = token.pt > 0.10 ? ctx->vocab.id_to_token.at(token.tid) : "[?]"; + WHISPER_LOG_DEBUG("%s: id = %3d, decoder = %d, token = %6d, p = %6.3f, ts = %10s, %6.3f, result_len = %4d '%s'\n", + __func__, i, j, token.id, token.p, tt.c_str(), token.pt, result_len, ctx->vocab.id_to_token.at(token.id).c_str()); + } +#endif + + // end of segment + if (token.id == whisper_token_eot(ctx) || // end of text token + (params.max_tokens > 0 && i >= params.max_tokens) || // max tokens per segment reached + (has_ts && seek + seek_delta + 100 >= seek_end) // end of audio reached + ) { + if (result_len == 0 && !params.no_timestamps) { + if (seek + seek_delta + 100 >= seek_end) { + result_len = i + 1; + } else { + WHISPER_LOG_DEBUG("%s: decoder %d failed (result_len = 0)\n", __func__, j); + failed = true; + continue; + } + } + + if (params.single_segment || params.no_timestamps) { + result_len = i + 1; + seek_delta = 100*WHISPER_CHUNK_SIZE; + } + + WHISPER_LOG_DEBUG("%s: decoder %d completed\n", __func__, j); + completed = true; + continue; + } + + // TESTS: if no tensors are loaded, it means we are running tests + if (ctx->model.n_loaded == 0) { + seek_delta = 100*WHISPER_CHUNK_SIZE; + completed = true; + continue; + } + } + + // sometimes, the decoding can get stuck in a repetition loop + // this is an attempt to mitigate such cases - we flag the decoding as failed and use a fallback strategy + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100*WHISPER_CHUNK_SIZE/2)) { + WHISPER_LOG_DEBUG("%s: decoder %d: failed due to repetition loop\n", __func__, j); + failed = true; + continue; + } + } + + // check if all decoders have finished (i.e. 
completed or failed) + { + bool completed_all = true; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + completed_all = false; + } + + if (completed_all) { + break; + } + } + + state->t_sample_us += ggml_time_us() - t_start_sample_us; + + // obtain logits for the next token + { + auto & batch = state->batch; + + batch.n_tokens = 0; + + const int n_past = prompt.size() + i; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.failed || decoder.completed) { + continue; + } + + //WHISPER_LOG_DEBUG("%s: decoder %d: token %d, seek_delta %d\n", __func__, j, decoder.sequence.tokens.back().id, decoder.seek_delta); + + decoder.i_batch = batch.n_tokens; + + batch.token [batch.n_tokens] = decoder.sequence.tokens.back().id; + batch.pos [batch.n_tokens] = n_past; + batch.n_seq_id[batch.n_tokens] = 1; + batch.seq_id [batch.n_tokens][0] = j; + batch.logits [batch.n_tokens] = 1; + batch.n_tokens++; + } + + assert(batch.n_tokens > 0); + + if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) { + WHISPER_LOG_ERROR("%s: failed to decode\n", __func__); + return -8; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // TODO: avoid memory allocations, optimize, avoid threads? + { + std::atomic<int> j_cur(0); + + auto process = [&]() { + while (true) { + const int j = j_cur.fetch_add(1); + + if (j >= n_decoders_cur) { + break; + } + + auto & decoder = state->decoders[j]; + + if (decoder.failed || decoder.completed) { + continue; + } + + whisper_process_logits(*ctx, *state, decoder, params, t_cur); + } + }; + + const int n_threads = std::min(params.n_threads, n_decoders_cur); + + if (n_threads == 1) { + process(); + } else { + std::vector<std::thread> threads(n_threads - 1); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t] = std::thread(process); + } + + process(); + + for (int t = 0; t < n_threads - 1; ++t) { + threads[t].join(); + } + } + } + + state->t_sample_us += ggml_time_us() - t_start_sample_us; + } + } + + // rank the resulting sequences and select the best one + { + double best_score = -INFINITY; + + for (int j = 0; j < n_decoders_cur; ++j) { + auto & decoder = state->decoders[j]; + + if (decoder.failed) { + continue; + } + + decoder.sequence.tokens.resize(decoder.sequence.result_len); + whisper_sequence_score(params, decoder.sequence); + + WHISPER_LOG_DEBUG("%s: decoder %2d: score = %8.5f, result_len = %3d, avg_logprobs = %8.5f, entropy = %8.5f\n", + __func__, j, decoder.sequence.score, decoder.sequence.result_len, decoder.sequence.avg_logprobs, decoder.sequence.entropy); + + if (decoder.sequence.result_len > 32 && decoder.sequence.entropy < params.entropy_thold) { + WHISPER_LOG_DEBUG("%s: decoder %2d: failed due to entropy %8.5f < %8.5f\n", + __func__, j, decoder.sequence.entropy, params.entropy_thold); + + decoder.failed = true; + state->n_fail_h++; + + continue; + } + + if (best_score < decoder.sequence.score) { + best_score = decoder.sequence.score; + best_decoder_id = j; + } + } + + WHISPER_LOG_DEBUG("%s: best decoder = %d\n", __func__, best_decoder_id); + } + + bool success = true; + + // was the decoding successful for the current temperature? 
+ // do fallback only if: + // - we are not at the last temperature + if (it != (int) temperatures.size() - 1) { + const auto & decoder = state->decoders[best_decoder_id]; + + if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { + WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold); + success = false; + state->n_fail_p++; + } + } + + if (success) { + //for (auto & token : ctx->decoders[best_decoder_id].sequence.tokens) { + // WHISPER_LOG_DEBUG("%s: token = %d, p = %6.3f, pt = %6.3f, ts = %s, str = %s\n", __func__, token.id, token.p, token.pt, ctx->vocab.id_to_token.at(token.tid).c_str(), ctx->vocab.id_to_token.at(token.id).c_str()); + //} + + break; + } + + WHISPER_LOG_DEBUG("\n%s: failed to decode with temperature = %.2f\n", __func__, t_cur); + } + + // output results through a user-provided callback + { + const auto & best_decoder = state->decoders[best_decoder_id]; + + const auto seek_delta = best_decoder.seek_delta; + const auto result_len = best_decoder.sequence.result_len; + + const auto & tokens_cur = best_decoder.sequence.tokens; + + // [EXPERIMENTAL] Token-level timestamps with DTW + const auto n_segments_before = state->result_all.size(); + + //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); + + // update prompt_past + prompt_past.clear(); + if (prompt.front() == whisper_token_prev(ctx)) { + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); + } + + for (int i = 0; i < result_len; ++i) { + prompt_past.push_back(tokens_cur[i].id); + } + + if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { + int i0 = 0; + auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx)); + + std::string text; + bool speaker_turn_next = false; + + for (int i = 0; i < (int) tokens_cur.size(); i++) { + //printf("%s: %18s %6.3f %18s %6.3f\n", __func__, + // ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].p, + // ctx->vocab.id_to_token[tokens_cur[i].tid].c_str(), tokens_cur[i].pt); + + if (params.print_special || tokens_cur[i].id < whisper_token_eot(ctx)) { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + + // [TDRZ] record if speaker turn was predicted after current segment + if (params.tdrz_enable && tokens_cur[i].id == whisper_token_solm(ctx)) { + speaker_turn_next = true; + } + + if (tokens_cur[i].id > whisper_token_beg(ctx) && !params.single_segment) { + const auto t1 = seek + 2*(tokens_cur[i].tid - whisper_token_beg(ctx)); + + if (!text.empty()) { + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 
2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); + + result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next }); + for (int j = i0; j <= i; j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); + } + } + text = ""; + while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + i++; + } + i--; + t0 = t1; + i0 = i + 1; + speaker_turn_next = false; + } + } + + if (!text.empty()) { + const auto t1 = seek + seek_delta; + + const auto tt0 = params.speed_up ? 2*t0 : t0; + const auto tt1 = params.speed_up ? 2*t1 : t1; + + if (params.print_realtime) { + if (params.print_timestamps) { + printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); + } else { + printf("%s", text.c_str()); + fflush(stdout); + } + } + + result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); + for (int j = i0; j < (int) tokens_cur.size(); j++) { + result_all.back().tokens.push_back(tokens_cur[j]); + } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + *ctx, *state, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(*ctx, *state, params.max_len, params.split_on_word); + } + } + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, n_new, params.new_segment_callback_user_data); + } + } + } + + // FIXME: will timestamp offsets be correct? 
+ // [EXPERIMENTAL] Token-level timestamps with DTW + { + const auto n_segments = state->result_all.size() - n_segments_before; + if (ctx->params.dtw_token_timestamps && n_segments) { + const int n_frames = std::min(std::min(WHISPER_CHUNK_SIZE * 100, seek_delta), seek_end - seek); + whisper_exp_compute_token_level_timestamps_dtw( + ctx, state, params, result_all.size() - n_segments, n_segments, seek, n_frames, 7, params.n_threads); + } + } + + // update audio window + seek += seek_delta; + + WHISPER_LOG_DEBUG("seek = %d, seek_delta = %d\n", seek, seek_delta); + } + } + + return 0; +} + +int whisper_full( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples) { + return whisper_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int whisper_full_parallel( + struct whisper_context * ctx, + struct whisper_full_params params, + const float * samples, + int n_samples, + int n_processors) { + if (n_processors == 1) { + return whisper_full(ctx, params, samples, n_samples); + } + int ret = 0; + + // prepare separate states for each thread + std::vector<whisper_state *> states; + + const int offset_samples = (WHISPER_SAMPLE_RATE*params.offset_ms)/1000; + const int n_samples_per_processor = (n_samples - offset_samples)/n_processors; + + // the calling thread will process the first chunk + // while the other threads will process the remaining chunks + + std::vector<std::thread> workers(n_processors - 1); + for (int i = 0; i < n_processors - 1; ++i) { + // create a new state for each thread + states.push_back(whisper_init_state(ctx)); + + const int start_samples = offset_samples + (i + 1)*n_samples_per_processor; + const int n_samples_cur = (i == n_processors - 2) ? n_samples - start_samples : n_samples_per_processor; + + auto params_cur = params; + + params_cur.offset_ms = 0; + params_cur.print_progress = false; + params_cur.print_realtime = false; + + params_cur.new_segment_callback = nullptr; + params_cur.new_segment_callback_user_data = nullptr; + + params_cur.progress_callback = nullptr; + params_cur.progress_callback_user_data = nullptr; + + workers[i] = std::thread(whisper_full_with_state, ctx, states[i], std::move(params_cur), samples + start_samples, n_samples_cur); + }
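+ + // note (illustrative): with 16 kHz audio, a 60 s input (960000 samples), offset_ms = 0 and n_processors = 2, + // n_samples_per_processor = 480000, so the calling thread below transcribes samples [0, 480000) while the + // single worker handles [480000, 960000).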
+ + { + auto params_cur = params; + + // We also need to disable print_realtime for this one, otherwise it will only be shown for the first chunk. + params_cur.print_realtime = false; + + // Run the first transcription using the default state, but only for the first chunk. + ret = whisper_full_with_state(ctx, ctx->state, std::move(params_cur), samples, offset_samples + n_samples_per_processor); + } + + for (int i = 0; i < n_processors - 1; ++i) { + workers[i].join(); + } + + const int64_t offset_t = (int64_t) params.offset_ms/10.0; + + // combine results into result_state->result_all from all other states + for (int i = 0; i < n_processors - 1; ++i) { + auto& results_i = states[i]->result_all; + + for (auto& result : results_i) { + // correct the segment timestamp taking into account the offset + result.t0 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + result.t1 += 100 * ((i + 1) * n_samples_per_processor) / WHISPER_SAMPLE_RATE + offset_t; + + // make sure that segments are not overlapping + if (!ctx->state->result_all.empty()) { + result.t0 = std::max(result.t0, ctx->state->result_all.back().t1); + } + + ctx->state->result_all.push_back(std::move(result)); + + // call the new_segment_callback for each segment + if (params.new_segment_callback) { + params.new_segment_callback(ctx, ctx->state, 1, params.new_segment_callback_user_data); + } + } + + ctx->state->t_mel_us += states[i]->t_mel_us; + + ctx->state->t_sample_us += states[i]->t_sample_us; + ctx->state->t_encode_us += states[i]->t_encode_us; + ctx->state->t_decode_us += states[i]->t_decode_us; + ctx->state->t_batchd_us += states[i]->t_batchd_us; + ctx->state->t_prompt_us += states[i]->t_prompt_us; + + ctx->state->n_sample += states[i]->n_sample; + ctx->state->n_encode += states[i]->n_encode; + ctx->state->n_decode += states[i]->n_decode; + ctx->state->n_batchd += states[i]->n_batchd; + ctx->state->n_prompt += states[i]->n_prompt; + + whisper_free_state(states[i]); + } + + // average the timings + ctx->state->t_mel_us /= n_processors; + ctx->state->t_sample_us /= n_processors; + ctx->state->t_encode_us /= n_processors; + ctx->state->t_decode_us /= n_processors; + + // print information about the audio boundaries + WHISPER_LOG_WARN("\n"); + WHISPER_LOG_WARN("%s: the audio has been split into %d chunks at the following times:\n", __func__, n_processors); + for (int i = 0; i < n_processors - 1; ++i) { + WHISPER_LOG_WARN("%s: split %d - %s\n", __func__, (i + 1), to_timestamp(100*((i + 1)*n_samples_per_processor)/WHISPER_SAMPLE_RATE + offset_t).c_str()); + } + WHISPER_LOG_WARN("%s: the transcription quality may be degraded near these boundaries\n", __func__); + + return ret; +} + +int whisper_full_n_segments_from_state(struct whisper_state * state) { + return state->result_all.size(); +} + +int whisper_full_n_segments(struct whisper_context * ctx) { + return ctx->state->result_all.size(); +} + +int whisper_full_lang_id_from_state(struct whisper_state * state) { + return state->lang_id; +} + +int whisper_full_lang_id(struct whisper_context * ctx) { + return ctx->state->lang_id; +} + +int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].t0; +} + +int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].t1; +} + +int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].t1; +} + +bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment) { + return
state->result_all[i_segment].speaker_turn_next; +} + +bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].speaker_turn_next; +} + +const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); +} + +const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].text.c_str(); +} + +int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); +} + +int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; +} + +whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; +} + +struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; +} + +float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +// ================================================================================================= + +// +// Temporary interface needed for exposing ggml interface +// Will be removed in the future when ggml becomes a separate library +// + +WHISPER_API int whisper_bench_memcpy(int n_threads) { + fputs(whisper_bench_memcpy_str(n_threads), stderr); + return 0; +} + +WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) { + static std::string s; + s = ""; + char strbuf[256]; + + ggml_time_init(); + + size_t n = 20; + size_t arr = n_threads > 0 ? 
1024llu : n_threads; // trick to avoid compiler optimizations + + // 1GB array + const size_t size = arr*1e6; + + double sum = 0.0; + + // heat-up + { + char * src = (char *) malloc(size); + char * dst = (char *) malloc(size); + + for (size_t i = 0; i < size; i++) src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + + for (size_t i = 0; i < n; i++) { + const int64_t t0 = ggml_time_us(); + + memcpy(dst, src, size); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + + src[rand() % size] = rand() % 256; + } + + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (heat-up)\n", (double) (n*size)/(tsum*1e9)); + s += strbuf; + + // needed to prevent the compiler from optimizing the memcpy away + { + for (size_t i = 0; i < size; i++) sum += dst[i]; + } + + free(src); + free(dst); + } + + // single-thread + { + char * src = (char *) malloc(size); + char * dst = (char *) malloc(size); + + for (size_t i = 0; i < size; i++) src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + + for (size_t i = 0; i < n; i++) { + const int64_t t0 = ggml_time_us(); + + memcpy(dst, src, size); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + + src[rand() % size] = rand() % 256; + } + + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s ( 1 thread)\n", (double) (n*size)/(tsum*1e9)); + s += strbuf; + + // needed to prevent the compiler from optimizing the memcpy away + { + for (size_t i = 0; i < size; i++) sum += dst[i]; + } + + free(src); + free(dst); + } + + // multi-thread + + for (int32_t k = 1; k <= n_threads; k++) { + char * src = (char *) malloc(size); + char * dst = (char *) malloc(size); + + for (size_t i = 0; i < size; i++) src[i] = i; + + memcpy(dst, src, size); // heat-up + + double tsum = 0.0; + + auto helper = [&](int th) { + const int64_t i0 = (th + 0)*size/k; + const int64_t i1 = (th + 1)*size/k; + + for (size_t i = 0; i < n; i++) { + memcpy(dst + i0, src + i0, i1 - i0); + + src[i0 + rand() % (i1 - i0)] = rand() % 256; + }; + }; + + const int64_t t0 = ggml_time_us(); + + std::vector<std::thread> threads(k - 1); + for (int32_t th = 0; th < k - 1; ++th) { + threads[th] = std::thread(helper, th); + } + + helper(k - 1); + + for (int32_t th = 0; th < k - 1; ++th) { + threads[th].join(); + } + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + + snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k); + s += strbuf; + + // needed to prevent the compiler from optimizing the memcpy away + { + for (size_t i = 0; i < size; i++) sum += dst[i]; + } + + free(src); + free(dst); + } + + snprintf(strbuf, sizeof(strbuf), "sum: %f\n", sum); + s += strbuf; + + return s.c_str(); +} + +WHISPER_API int whisper_bench_ggml_mul_mat(int n_threads) { + fputs(whisper_bench_ggml_mul_mat_str(n_threads), stderr); + return 0; +} + +WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { + static std::string s; + s = ""; + char strbuf[256]; + + ggml_time_init(); + + const int n_max = 128; + + const std::vector<size_t> sizes = { + 64, 128, 256, 512, 1024, 2048, 4096, + }; + + const size_t N_max = sizes.back(); + + // a: N*N*sizeof(float) + // b: N*N*sizeof(float) + // c: N*N*sizeof(float) + // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) + std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead()); + std::vector<uint8_t> work; + + // put a bunch of random data in the buffer + for (size_t i = 0; i < buf.size(); i++) buf[i] = i;
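+ + // note (illustrative): each N x N by N x N product below costs about 2*N^3 FLOPs, so at N = 4096 a + // single multiplication is roughly 2*4096^3 ~= 137 GFLOP; the per-type speeds are computed as + // (2*N^3*n)/tsum and reported in GFLOPS.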
for (int j = 0; j < (int) sizes.size(); j++) { + int n_q4_0 = 0; + int n_q4_1 = 0; + int n_q5_0 = 0; + int n_q5_1 = 0; + int n_q8_0 = 0; + int n_fp16 = 0; + int n_fp32 = 0; + + // GFLOPS + double s_q4_0 = 0.0; + double s_q4_1 = 0.0; + double s_q5_0 = 0.0; + double s_q5_1 = 0.0; + double s_q8_0 = 0.0; + double s_fp16 = 0.0; + double s_fp32 = 0.0; + + const size_t N = sizes[j]; + + for (int k = 0; k < 7; ++k) { + const ggml_type wtype = + k == 0 ? GGML_TYPE_Q4_0 : + k == 1 ? GGML_TYPE_Q4_1 : + k == 2 ? GGML_TYPE_Q5_0 : + k == 3 ? GGML_TYPE_Q5_1 : + k == 4 ? GGML_TYPE_Q8_0 : + k == 5 ? GGML_TYPE_F16 : GGML_TYPE_F32; + + double & s = k == 0 ? s_q4_0 : k == 1 ? s_q4_1 : k == 2 ? s_q5_0 : k == 3 ? s_q5_1 : k == 4 ? s_q8_0 : k == 5 ? s_fp16 : /*k == 6*/ s_fp32; + int & n = k == 0 ? n_q4_0 : k == 1 ? n_q4_1 : k == 2 ? n_q5_0 : k == 3 ? n_q5_1 : k == 4 ? n_q8_0 : k == 5 ? n_fp16 : /*k == 6*/ n_fp32; + + struct ggml_init_params gparams = { + /*.mem_size =*/ buf.size(), + /*.mem_buffer =*/ buf.data(), + /*.no_alloc =*/ false, + }; + + struct ggml_context * ctx0 = ggml_init(gparams); + + struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N); + struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); + + struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b); + + struct ggml_cgraph * gf = ggml_new_graph(ctx0); + + ggml_build_forward_expand(gf, c); + + double tsum = 0.0; + + // heat-up + ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr); + + for (int i = 0; i < n_max; ++i) { + const int64_t t0 = ggml_time_us(); + + ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr); + + const int64_t t1 = ggml_time_us(); + + tsum += (t1 - t0)*1e-6; + n++; + + if (tsum > 1.0 && n >= 3) { + break; + } + } + + ggml_free(ctx0); + + s = ((2.0*N*N*N*n)/tsum)*1e-9; + } + + // Q4_0 | Q4_1 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q4_0 %7.1f GFLOPS (%3d runs) | Q4_1 %7.1f GFLOPS (%3d runs)\n", + N, N, s_q4_0, n_q4_0, s_q4_1, n_q4_1); + s += strbuf; + + // Q5_0 | Q5_1 | Q8_0 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: Q5_0 %7.1f GFLOPS (%3d runs) | Q5_1 %7.1f GFLOPS (%3d runs) | Q8_0 %7.1f GFLOPS (%3d runs)\n", + N, N, s_q5_0, n_q5_0, s_q5_1, n_q5_1, s_q8_0, n_q8_0); + s += strbuf; + + // F16 | F32 + snprintf(strbuf, sizeof(strbuf), "%4zu x %4zu: F16 %7.1f GFLOPS (%3d runs) | F32 %7.1f GFLOPS (%3d runs)\n", + N, N, s_fp16, n_fp16, s_fp32, n_fp32); + s += strbuf; + } + + return s.c_str(); +} + +// ================================================================================================= + +// ================================================================================================= + +// +// Experimental stuff below +// +// Not sure if these should be part of the library at all, because the quality of the results is not guaranteed.
Might get removed at some point unless a robust algorithm implementation is found +// + +// ================================================================================================= + +// +// token-level timestamps +// + +int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} +int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*whisper_sample_rate)/100))); +} + +static int64_t sample_to_timestamp(int i_sample) { + return (100ll*i_sample)/WHISPER_SAMPLE_RATE; +} + +// a cost-function / heuristic that is high for text that takes longer to pronounce +// obviously, can be improved +static float voice_length(const std::string & text) { + float res = 0.0f; + + for (char c : text) { + if (c == ' ') { + res += 0.01f; + } else if (c == ',') { + res += 2.00f; + } else if (c == '.') { + res += 3.00f; + } else if (c == '!') { + res += 3.00f; + } else if (c == '?') { + res += 3.00f; + } else if (c >= '0' && c <= '9') { + res += 3.00f; + } else { + res += 1.00f; + } + } + + return res; +} + +// average the fabs of the signal +static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) { + const int hw = n_samples_per_half_window; + + std::vector<float> result(n_samples); + + for (int i = 0; i < n_samples; i++) { + float sum = 0; + for (int j = -hw; j <= hw; j++) { + if (i + j >= 0 && i + j < n_samples) { + sum += fabs(signal[i + j]); + } + } + result[i] = sum/(2*hw + 1); + } + + return result; +} + +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context & ctx, + struct whisper_state & state, + int i_segment, + float thold_pt, + float thold_ptsum) { + auto & segment = state.result_all[i_segment]; + auto & tokens = segment.tokens; + + const int n_samples = state.energy.size(); + + if (n_samples == 0) { + WHISPER_LOG_ERROR("%s: no signal data available\n", __func__); + return; + } + + const int64_t t0 = segment.t0; + const int64_t t1 = segment.t1; + + const int n = tokens.size(); + + if (n == 0) { + return; + } + + if (n == 1) { + tokens[0].t0 = t0; + tokens[0].t1 = t1; + + return; + } + + auto & t_beg = state.t_beg; + auto & t_last = state.t_last; + auto & tid_last = state.tid_last; + + for (int j = 0; j < n; ++j) { + auto & token = tokens[j]; + + if (j == 0) { + if (token.id == whisper_token_beg(&ctx)) { + tokens[j ].t0 = t0; + tokens[j ].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(&ctx); + } else { + tokens[j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(&ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].vlen = voice_length(whisper_token_to_str(&ctx, token.id)); + + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { + if (j > 0) { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1;
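+ + // note (illustrative): the block below splits each run of tokens with unknown timestamps in + // proportion to the voice lengths computed above; e.g. three tokens with vlen 1.0, 2.0 and 3.0 + // spanning dt = 90 (centiseconds) get boundaries at +15 and +45, dividing the interval 15/30/45.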
t0=%lld t1=%lld\n", p0, p1, tokens[p0].t0, tokens[p1].t1); + + if (p1 > p0) { + double psum = 0.0; + for (int j = p0; j <= p1; j++) { + psum += tokens[j].vlen; + } + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(&ctx)) { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += state.energy[k]; + } + + const float thold = 0.5*sum/ns; + + { + int k = s0; + if (state.energy[k] > thold && j > 0) { + while (k > 0 && state.energy[k] > thold) { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (state.energy[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (state.energy[k] > thold) { + while (k < n_samples - 1 && state.energy[k] > thold) { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (state.energy[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + // if (j < n - 1) { + // tokens[j].t1 = tokens[j].t1 + t_expand; + // } + // } + //} + + // debug info + //for (int j = 0; j < n; ++j) { + // const auto & token = tokens[j]; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? 
whisper_token_to_str(&ctx, token.tid) : "[?]"; + // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(&ctx, token.id)); + + // if (tokens[j].id >= whisper_token_eot(&ctx)) { + // continue; + // } + //} +} + +// +// token level timestamps - dtw version +// + +// n_text_layer -> total text layers on model +// n_head -> total heads per text layer on model +static std::vector get_alignment_heads_by_layer(const whisper_context_params & cparams, int il, int n_text_layer, int n_head) { + std::vector ret; + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { + return ret; + } else if (cparams.dtw_aheads_preset == WHISPER_AHEADS_N_TOP_MOST) { + if (il >= n_text_layer - cparams.dtw_n_top) { + for (int32_t i = 0; i < n_head; ++i) { + ret.push_back(i); + } + } + } else { + const auto aheads = cparams.dtw_aheads_preset == WHISPER_AHEADS_CUSTOM ? cparams.dtw_aheads : g_aheads.at(cparams.dtw_aheads_preset); + for (size_t i = 0; i < aheads.n_heads; ++i) { + if (aheads.heads[i].n_text_layer == il) { + ret.push_back(aheads.heads[i].n_head); + } + } + } + return ret; +} + +// dtw + backtrace to return found path +// based on +// https://github.com/openai/whisper/blob/main/whisper/timing.py#L83 +static ggml_tensor * dtw_and_backtrace(ggml_context * ctx, ggml_tensor * x) { + WHISPER_ASSERT(ggml_n_dims(x) == 2); + + int64_t N = x->ne[0]; + int64_t M = x->ne[1]; + struct ggml_tensor * cost = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, N + 1, M + 1); + struct ggml_tensor * trace = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, N + 1, M + 1); + + cost = ggml_set_f32(cost, INFINITY); + trace = ggml_set_f32(trace, -1); + ggml_set_f32_nd(cost, 0, 0, 0, 0, 0.0); + + // dtw + // supposedly can be optmized by computing diagonals in parallel ? + // Not sure it is worth it since x will be GENERATED_TOKENS*1500 size at most. + for (int64_t j = 1; j < M + 1; ++j) { + for (int64_t i = 1; i < N + 1; ++i) { + float c0 = ggml_get_f32_nd(cost, i - 1, j - 1, 0, 0); + float c1 = ggml_get_f32_nd(cost, i - 1, j, 0, 0); + float c2 = ggml_get_f32_nd(cost, i, j - 1, 0, 0); + + float c; + int32_t t; + if (c0 < c1 && c0 < c2) { + c = c0; + t = 0; + } else if (c1 < c0 && c1 < c2) { + c = c1; + t = 1; + } else { + c = c2; + t = 2; + } + + c = ggml_get_f32_nd(x, i - 1, j - 1, 0, 0) + c; + ggml_set_f32_nd(cost, i, j, 0, 0, c); + ggml_set_i32_nd(trace, i, j, 0, 0, t); + } + } + + // Backtrace + const int64_t BT_MAX_ROWS = N + M - 1; + struct ggml_tensor * bt = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, BT_MAX_ROWS, 2); + // trace[0, :] = 2; + for (int64_t i = 0; i < M + 1; ++i) + ggml_set_i32_nd(trace, 0, i, 0, 0, 2); + //trace[:, 0] = 1; + for (int64_t i = 0; i < N + 1; ++i) + ggml_set_i32_nd(trace, i, 0, 0, 0, 1); + int bt_row_idx = BT_MAX_ROWS - 1; + int64_t i = N; + int64_t j = M; + while (i > 0 || j > 0) { + ggml_set_i32_nd(bt, bt_row_idx, 0, 0, 0, i - 1); + ggml_set_i32_nd(bt, bt_row_idx, 1, 0, 0, j - 1); + --bt_row_idx; + + int32_t t = ggml_get_i32_nd(trace, i, j, 0, 0); + if (t == 0) { + --i; + --j; + } else if (t == 1) { + --i; + } else if (t == 2) { + --j; + } else { + WHISPER_ASSERT(0); + } + } + + // FIXME: manual clip/transpose might not be the most efficient way? (e.g. 
+ + // FIXME: manual clip/transpose might not be the most efficient way? (e.g. use ggml funcs) + // Clip + transpose + // This might not be entirely necessary for our case, but leaving it for now so output matrix + // is identical to dtw on openAI timing.py + const int64_t result_n_cols = BT_MAX_ROWS-bt_row_idx-1; + ggml_tensor * r = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, 2, result_n_cols); + for (int64_t i = 0; i < 2; ++i) { + for (int64_t j = 0; j < result_n_cols; ++j) { + int32_t v = ggml_get_i32_nd(bt, j+bt_row_idx+1, i, 0, 0); + ggml_set_i32_nd(r, i, j, 0, 0, v); + } + } + + return r; +} + +struct median_filter_user_data { + int filter_width; +}; + +static void median_filter(struct ggml_tensor * dst, const struct ggml_tensor * a, int ith, int nth, void * userdata) { + int filter_width = ((median_filter_user_data *) userdata)->filter_width; + WHISPER_ASSERT(nth == 1); + WHISPER_ASSERT(ith == 0); + WHISPER_ASSERT(filter_width < a->ne[2]); + WHISPER_ASSERT(filter_width % 2); + WHISPER_ASSERT(ggml_n_dims(a) == 3); + WHISPER_ASSERT(a->type == GGML_TYPE_F32); + + std::vector<float> filter; + filter.reserve(filter_width); + for (int64_t i = 0; i < a->ne[0]; ++i) { + for (int64_t j = 0; j < a->ne[1]; ++j) { + for (int64_t k = 0; k < a->ne[2]; ++k) { + for (int64_t off = -filter_width/2; off <= filter_width/2; ++off) { + // "reflect" padding + int64_t idx = k + off; + if (idx < 0) { + idx = -idx; + } else if (idx >= a->ne[2]) { + idx = 2*(a->ne[2] - 1) - idx; + } + + filter.push_back(ggml_get_f32_nd(a, i, j, idx, 0)); + } + std::sort(filter.begin(), filter.end()); + const float v = filter[filter.size()/2]; + ggml_set_f32_nd(dst, i, j, k, 0, v); + filter.clear(); + } + } + } +} + +static void whisper_exp_compute_token_level_timestamps_dtw( + struct whisper_context * ctx, + struct whisper_state * state, + struct whisper_full_params params, + int i_segment, + size_t n_segments, + int seek, + int n_frames, + int medfilt_width, + int n_threads) +{ + const int n_audio_ctx = state->exp_n_audio_ctx > 0 ? state->exp_n_audio_ctx : ctx->model.hparams.n_audio_ctx; + WHISPER_ASSERT(medfilt_width % 2); + WHISPER_ASSERT(n_frames <= n_audio_ctx * 2); + WHISPER_ASSERT(ctx->params.dtw_aheads_preset != WHISPER_AHEADS_NONE); + + // FIXME: Allocating mem every time we call this func + // Our ggml buffer should be pre-allocated somewhere during init and reused + // when we call this function + struct ggml_init_params gparams = { + /*.mem_size =*/ ctx->params.dtw_mem_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + struct ggml_context * gctx = ggml_init(gparams); + + // Build token sequence that will be passed to decoder + // sot + [lang] + text result + eot + std::vector<whisper_token> tokens = { whisper_token_sot(ctx), }; + if (whisper_is_multilingual(ctx)) { + const int lang_id = whisper_lang_id(params.language); + state->lang_id = lang_id; + tokens.push_back(whisper_token_lang(ctx, lang_id)); + } + const size_t sot_sequence_length = tokens.size(); + tokens.push_back(whisper_token_not(ctx)); + for (size_t i = i_segment; i < i_segment + n_segments; ++i) { + auto & segment = state->result_all[i]; + for (auto &t: segment.tokens) { + // Only text tokens + if (t.id < whisper_token_eot(ctx)) { + tokens.push_back(t.id); + } + } + } + tokens.push_back(whisper_token_eot(ctx)); + + // Get result tokens, pass them along to decoder to get cross attention QKs + // used in timestamping + // Decoder already returns only alignment head QKs, already concatenated in + // one tensor.
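+ + // note (illustrative): for a multilingual model the sequence built above looks like + // [sot, lang, not, t_1 ... t_n, eot], so sot_sequence_length == 2 (1 for English-only models, + // which have no language token); these surrounding special tokens are trimmed from the + // alignment matrix further below.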
whisper_kv_cache_clear(state->kv_self); + whisper_batch_prep_legacy(state->batch, tokens.data(), tokens.size(), 0, 0); + whisper_kv_cache_seq_rm(state->kv_self, 0, 0, -1); + if (!whisper_decode_internal(*ctx, *state, state->batch, n_threads, true, nullptr, nullptr)) { + WHISPER_LOG_INFO("DECODER FAILED\n"); + WHISPER_ASSERT(0); + } + WHISPER_ASSERT(state->aheads_cross_QKs != nullptr); + + const auto n_audio_tokens = n_frames/2; + WHISPER_ASSERT(state->aheads_cross_QKs != NULL); + WHISPER_ASSERT(n_audio_tokens <= state->aheads_cross_QKs->ne[1]); + const auto n_tokens = state->aheads_cross_QKs->ne[0]; + const auto n_heads = state->aheads_cross_QKs->ne[2]; + + // Copy data from decoder buffer to a local CPU tensor, discarding unused audio + // tokens (i.e. discarding rows at the end of tensor) + // IN: Tensor with N_TOKENS*audio_ctx*N_ALIGNMENT_HEADS dims + // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims + WHISPER_ASSERT(state->aheads_cross_QKs->type == GGML_TYPE_F32); + WHISPER_ASSERT(ggml_is_contiguous(state->aheads_cross_QKs)); + ggml_tensor * w = ggml_new_tensor_3d(gctx, GGML_TYPE_F32, n_tokens, n_audio_tokens, n_heads); + auto & data = state->aheads_cross_QKs_data; + data.resize(n_tokens * n_audio_ctx * n_heads); + ggml_backend_tensor_get(state->aheads_cross_QKs, data.data(), 0, sizeof(float) * n_tokens * n_audio_ctx * n_heads); + for (int k = 0; k < n_heads; ++k) { + for (int j = 0; j < n_audio_tokens; ++j) { + memcpy( + (char *) w->data + j * w->nb[1] + k * w->nb[2], + data.data() + j * n_tokens + k * n_tokens * n_audio_ctx, + n_tokens * sizeof(float) + ); + } + } + + // Normalize - in original OpenAI code, this is done over dim=-2. In this case, + // we already permuted N_TOKENS dimension to columns on last loop, because ggml_norm + // operates over columns. Afterwards, permute to a shape that facilitates mean + // operation (after median filter) + // IN: Tensor with N_TOKENS*N_AUDIO_TOKENS*N_ALIGNMENT_HEADS dims + // OUT: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims + w = ggml_norm(gctx, w, 1e-9); + w = ggml_permute(gctx, ggml_permute(gctx, w, 2, 1, 0, 3), 0, 2, 1, 3); + + // Pass median filter - this is done over AUDIO_TOKENS dimension.
+ // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims + // OUT: Same dims + median_filter_user_data mf_user_data = {medfilt_width}; + w = ggml_map_custom1(gctx, w, median_filter, 1, &mf_user_data); + + // Take mean over columns, scale by -1, reshape to 2D tensor, remove SOT sequence and EOT + // IN: Tensor with N_ALIGNMENT_HEADS*N_TOKENS*N_AUDIO_TOKENS dims + // OUT: Tensor with N_TOKENS*N_AUDIO_TOKENS dims + w = ggml_mean(gctx, w); + w = ggml_scale(gctx, w, -1.0); + w = ggml_reshape_2d(gctx, w, w->ne[1], w->ne[2]); + + // Remove SOT sequence and EOT + // Out dimension is (N_TOKENS-sot_sequence_length-1)*N_AUDIO_TOKENS + w = ggml_view_2d(gctx, w, w->ne[0] - sot_sequence_length - 1, w->ne[1], w->nb[1], sot_sequence_length * w->nb[0]); + + // Compute + struct ggml_cgraph * gf = ggml_new_graph(gctx); + ggml_build_forward_expand(gf, w); + ggml_graph_compute_with_ctx(gctx, gf, n_threads); + + ggml_tensor * alignment = dtw_and_backtrace(gctx, w); + + // Place timestamps on segments + int32_t last_v = 0; + auto seg_i = state->result_all.begin() + i_segment; + auto tok_i = seg_i->tokens.begin(); + for (int i = 0; i < alignment->ne[1]; ++i) { + int32_t v = ggml_get_i32_nd(alignment, 0, i, 0, 0); + if (v != last_v) { + int32_t time_index = ggml_get_i32_nd(alignment, 1, i, 0, 0); + int64_t timestamp = (time_index * 2) + seek; // Each index on DTW result = 20 ms of audio + last_v = v; + + // Skip non-text tokens + while (!(tok_i->id < whisper_token_eot(ctx))) { + ++tok_i; + if (tok_i == seg_i->tokens.end()) { + ++seg_i; + tok_i = seg_i->tokens.begin(); + } + } + + tok_i->t_dtw = timestamp; + ++tok_i; + if (tok_i == seg_i->tokens.end()) { + ++seg_i; + tok_i = seg_i->tokens.begin(); + } + } + } + + // Print DTW timestamps + /*for (size_t i = i_segment; i < i_segment + n_segments; ++i) { + auto & segment = state->result_all[i]; + for (auto &t: segment.tokens) { + const char * tok = whisper_token_to_str(ctx, t.id); + fprintf(stderr, "|%s|(%.2f) ", tok, (float)t.t_dtw/100); + } + fprintf(stderr, "\n"); + }*/ + + ggml_free(gctx); +} + +void whisper_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : whisper_log_callback_default; + g_state.log_callback_user_data = user_data; +} + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void whisper_log_internal(ggml_log_level level, const char * format, ...)
{ + va_list args; + va_start(args, format); + char buffer[1024]; + va_list args_copy; + va_copy(args_copy, args); // the first vsnprintf consumes args; keep a copy for the fallback pass + int len = vsnprintf(buffer, 1024, format, args); + if (len < 1024) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args_copy); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args_copy); + va_end(args); +} + +static void whisper_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + + + +bool is_file_exist(const char *fileName) +{ + std::ifstream infile(fileName); + return infile.good(); +} diff --git a/otherarch/whispercpp/whisper.h b/otherarch/whispercpp/whisper.h new file mode 100644 index 000000000..a7a93fd2d --- /dev/null +++ b/otherarch/whispercpp/whisper.h @@ -0,0 +1,681 @@ +#ifndef WHISPER_H +#define WHISPER_H + +#include "ggml.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdbool.h> + +#ifdef __GNUC__ +# define WHISPER_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define WHISPER_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define WHISPER_DEPRECATED(func, hint) func +#endif + +#ifdef WHISPER_SHARED +# ifdef _WIN32 +# ifdef WHISPER_BUILD +# define WHISPER_API __declspec(dllexport) +# else +# define WHISPER_API __declspec(dllimport) +# endif +# else +# define WHISPER_API __attribute__ ((visibility ("default"))) +# endif +#else +# define WHISPER_API +#endif + +#define WHISPER_SAMPLE_RATE 16000 +#define WHISPER_N_FFT 400 +#define WHISPER_HOP_LENGTH 160 +#define WHISPER_CHUNK_SIZE 30 + +#ifdef __cplusplus +extern "C" { +#endif + + // + // C interface + // + // The following interface is thread-safe as long as the same whisper_context is not used by multiple threads + // concurrently. + // + // Basic usage: + // + // #include "whisper.h" + // + // ... + // + // whisper_context_params cparams = whisper_context_default_params(); + // + // struct whisper_context * ctx = whisper_init_from_file_with_params("/path/to/ggml-base.en.bin", cparams); + // + // struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + // + // if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + // fprintf(stderr, "failed to process audio\n"); + // return 7; + // } + // + // const int n_segments = whisper_full_n_segments(ctx); + // for (int i = 0; i < n_segments; ++i) { + // const char * text = whisper_full_get_segment_text(ctx, i); + // printf("%s", text); + // } + // + // whisper_free(ctx); + // + // ... + // + // This is a demonstration of the most straightforward usage of the library. + // "pcmf32" contains the RAW audio data in 32-bit floating point format. + // + // The interface also allows for more fine-grained control over the computation, but it requires a deeper + // understanding of how the model works.
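+ // + // Segment timestamps (expressed in centiseconds, i.e. units of 10 ms) can, for example, be read in + // the same loop via: + // + //     const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + //     const int64_t t1 = whisper_full_get_segment_t1(ctx, i);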
+ // + + struct whisper_context; + struct whisper_state; + struct whisper_full_params; + + typedef int32_t whisper_pos; + typedef int32_t whisper_token; + typedef int32_t whisper_seq_id; + + enum whisper_alignment_heads_preset { + WHISPER_AHEADS_NONE, + WHISPER_AHEADS_N_TOP_MOST, // All heads from the N-top-most text-layers + WHISPER_AHEADS_CUSTOM, + WHISPER_AHEADS_TINY_EN, + WHISPER_AHEADS_TINY, + WHISPER_AHEADS_BASE_EN, + WHISPER_AHEADS_BASE, + WHISPER_AHEADS_SMALL_EN, + WHISPER_AHEADS_SMALL, + WHISPER_AHEADS_MEDIUM_EN, + WHISPER_AHEADS_MEDIUM, + WHISPER_AHEADS_LARGE_V1, + WHISPER_AHEADS_LARGE_V2, + WHISPER_AHEADS_LARGE_V3, + }; + + typedef struct whisper_ahead { + int n_text_layer; + int n_head; + } whisper_ahead; + + typedef struct whisper_aheads { + size_t n_heads; + const whisper_ahead * heads; + } whisper_aheads; + + struct whisper_context_params { + bool use_gpu; + bool flash_attn; + int gpu_device; // CUDA device + + // [EXPERIMENTAL] Token-level timestamps with DTW + bool dtw_token_timestamps; + enum whisper_alignment_heads_preset dtw_aheads_preset; + + int dtw_n_top; + struct whisper_aheads dtw_aheads; + + size_t dtw_mem_size; // TODO: remove + }; + + typedef struct whisper_token_data { + whisper_token id; // token id + whisper_token tid; // forced timestamp token id + + float p; // probability of the token + float plog; // log probability of the token + float pt; // probability of the timestamp token + float ptsum; // sum of probabilities of all timestamp tokens + + // token-level timestamp data + // do not use if you haven't computed token-level timestamps + int64_t t0; // start time of the token + int64_t t1; // end time of the token + + // [EXPERIMENTAL] Token-level timestamps with DTW + // do not use if you haven't computed token-level timestamps with dtw + // Roughly corresponds to the moment in audio in which the token was output + int64_t t_dtw; + + float vlen; // voice length of the token + } whisper_token_data; + + typedef struct whisper_model_loader { + void * context; + + size_t (*read)(void * ctx, void * output, size_t read_size); + bool (*eof)(void * ctx); + void (*close)(void * ctx); + } whisper_model_loader; + + // grammar element type + enum whisper_gretype { + // end of rule definition + WHISPER_GRETYPE_END = 0, + + // start of alternate definition for rule + WHISPER_GRETYPE_ALT = 1, + + // non-terminal element: reference to rule + WHISPER_GRETYPE_RULE_REF = 2, + + // terminal element: character (code point) + WHISPER_GRETYPE_CHAR = 3, + + // inverse char(s) ([^a], [^a-b], [^abc]) + WHISPER_GRETYPE_CHAR_NOT = 4, + + // modifies a preceding WHISPER_GRETYPE_CHAR or WHISPER_GRETYPE_CHAR_ALT to + // be an inclusive range ([a-z]) + WHISPER_GRETYPE_CHAR_RNG_UPPER = 5, + + // modifies a preceding WHISPER_GRETYPE_CHAR or + // WHISPER_GRETYPE_CHAR_RNG_UPPER to add an alternate char to match ([ab], [a-zA]) + WHISPER_GRETYPE_CHAR_ALT = 6, + }; + + typedef struct whisper_grammar_element { + enum whisper_gretype type; + uint32_t value; // Unicode code point or rule ID + } whisper_grammar_element; + + // Various functions for loading a ggml whisper model. + // Allocate (almost) all memory needed for the model.
+ // Return NULL on failure + WHISPER_API struct whisper_context * whisper_init_from_file_with_params (const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params (struct whisper_model_loader * loader, struct whisper_context_params params); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using whisper_init_state() (#523) + WHISPER_API struct whisper_context * whisper_init_from_file_with_params_no_state (const char * path_model, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params); + WHISPER_API struct whisper_context * whisper_init_with_params_no_state (struct whisper_model_loader * loader, struct whisper_context_params params); + + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_file(const char * path_model), + "use whisper_init_from_file_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_buffer(void * buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init(struct whisper_model_loader * loader), + "use whisper_init_with_params instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_file_no_state(const char * path_model), + "use whisper_init_from_file_with_params_no_state instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size), + "use whisper_init_from_buffer_with_params_no_state instead" + ); + WHISPER_DEPRECATED( + WHISPER_API struct whisper_context * whisper_init_no_state(struct whisper_model_loader * loader), + "use whisper_init_with_params_no_state instead" + ); + + WHISPER_API struct whisper_state * whisper_init_state(struct whisper_context * ctx); + + // Given a context, enable use of OpenVINO for encode inference. + // model_path: Optional path to OpenVINO encoder IR model. If set to nullptr, + // the path will be generated from the ggml model path that was passed + // in to whisper_init_from_file. For example, if 'path_model' was + // "/path/to/ggml-base.en.bin", then OpenVINO IR model path will be + // assumed to be "/path/to/ggml-base.en-encoder-openvino.xml". + // device: OpenVINO device to run inference on ("CPU", "GPU", etc.) + // cache_dir: Optional cache directory that can speed up init time, especially for + // GPU, by caching compiled 'blobs' there. + // Set to nullptr if not used. + // Returns 0 on success. If OpenVINO is not enabled in build, this simply returns 1. 
+ WHISPER_API int whisper_ctx_init_openvino_encoder( + struct whisper_context * ctx, + const char * model_path, + const char * device, + const char * cache_dir); + + // Frees all allocated memory + WHISPER_API void whisper_free (struct whisper_context * ctx); + WHISPER_API void whisper_free_state(struct whisper_state * state); + WHISPER_API void whisper_free_params(struct whisper_full_params * params); + WHISPER_API void whisper_free_context_params(struct whisper_context_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel( + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + WHISPER_API int whisper_pcm_to_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // Convert RAW PCM audio to log mel spectrogram but applies a Phase Vocoder to speed up the audio x2. + // The resulting spectrogram is stored inside the default state of the provided whisper context. + // Returns 0 on success + WHISPER_API int whisper_pcm_to_mel_phase_vocoder( + struct whisper_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + WHISPER_API int whisper_pcm_to_mel_phase_vocoder_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided whisper context. + // Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 80 + // Returns 0 on success + WHISPER_API int whisper_set_mel( + struct whisper_context * ctx, + const float * data, + int n_len, + int n_mel); + + WHISPER_API int whisper_set_mel_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Whisper encoder on the log mel spectrogram stored inside the default state in the provided whisper context. + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + WHISPER_API int whisper_encode( + struct whisper_context * ctx, + int offset, + int n_threads); + + WHISPER_API int whisper_encode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset, + int n_threads); + + // Run the Whisper decoder to obtain the logits and probabilities for the next token. + // Make sure to call whisper_encode() first. + // tokens + n_tokens is the provided context for the decoder. + // n_past is the number of tokens to use from previous decoder calls. + // Returns 0 on success + // TODO: add support for multiple decoders + WHISPER_API int whisper_decode( + struct whisper_context * ctx, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + WHISPER_API int whisper_decode_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token * tokens, + int n_tokens, + int n_past, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. 
+ // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + // TODO: not sure if correct + WHISPER_API int whisper_tokenize( + struct whisper_context * ctx, + const char * text, + whisper_token * tokens, + int n_max_tokens); + + // Return the number of tokens in the provided text + // Equivalent to: -whisper_tokenize(ctx, text, NULL, 0) + WHISPER_API int whisper_token_count(struct whisper_context * ctx, const char * text); + + // Largest language id (i.e. number of available languages - 1) + WHISPER_API int whisper_lang_max_id(); + + // Return the id of the specified language, returns -1 if not found + // Examples: + // "de" -> 2 + // "german" -> 2 + WHISPER_API int whisper_lang_id(const char * lang); + + // Return the short string of the specified language id (e.g. 2 -> "de"), returns nullptr if not found + WHISPER_API const char * whisper_lang_str(int id); + + // Return the full name string of the specified language id (e.g. 2 -> "german"), returns nullptr if not found + WHISPER_API const char * whisper_lang_str_full(int id); + + // Use mel data at offset_ms to try and auto-detect the spoken language + // Make sure to call whisper_pcm_to_mel() or whisper_set_mel() first + // Returns the top language id or negative on failure + // If not null, fills the lang_probs array with the probabilities of all languages + // The array must be whisper_lang_max_id() + 1 in size + // ref: https://github.com/openai/whisper/blob/main/whisper/decoding.py#L18-L69 + WHISPER_API int whisper_lang_auto_detect( + struct whisper_context * ctx, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_lang_auto_detect_with_state( + struct whisper_context * ctx, + struct whisper_state * state, + int offset_ms, + int n_threads, + float * lang_probs); + + WHISPER_API int whisper_n_len (struct whisper_context * ctx); // mel length + WHISPER_API int whisper_n_len_from_state(struct whisper_state * state); // mel length + WHISPER_API int whisper_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_is_multilingual (struct whisper_context * ctx); + + WHISPER_API int whisper_model_n_vocab (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_state(struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_head (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_audio_layer(struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_ctx (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_state (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_head (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_text_layer (struct whisper_context * ctx); + WHISPER_API int whisper_model_n_mels (struct whisper_context * ctx); + WHISPER_API int whisper_model_ftype (struct whisper_context * ctx); + WHISPER_API int whisper_model_type (struct whisper_context * ctx); + + // Token logits obtained from the last call to whisper_decode() + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + WHISPER_API float * whisper_get_logits (struct whisper_context * ctx); + WHISPER_API float * whisper_get_logits_from_state(struct whisper_state * state);
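+ + // For example (illustrative; n_tokens here stands for whatever was passed to whisper_decode()), + // the scores used to pick the next token live in the last row: + // + //     float * logits = whisper_get_logits(ctx); + //     float * next_token_scores = logits + (n_tokens - 1)*whisper_n_vocab(ctx);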
+ + // Token Id -> String. Uses the vocabulary in the provided context + WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token); + WHISPER_API const char * whisper_model_type_readable(struct whisper_context * ctx); + + + // Special tokens + WHISPER_API whisper_token whisper_token_eot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_sot (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_solm(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_prev(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_nosp(struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_not (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_beg (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id); + + // Task tokens + WHISPER_API whisper_token whisper_token_translate (struct whisper_context * ctx); + WHISPER_API whisper_token whisper_token_transcribe(struct whisper_context * ctx); + + // Performance information from the default state. + WHISPER_API void whisper_print_timings(struct whisper_context * ctx); + WHISPER_API void whisper_reset_timings(struct whisper_context * ctx); + + // Print system information + WHISPER_API const char * whisper_print_system_info(void); + + //////////////////////////////////////////////////////////////////////////// + + // Available sampling strategies + enum whisper_sampling_strategy { + WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreedyDecoder + WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder + }; + + // Text segment callback + // Called on every newly generated text segment + // Use the whisper_full_...() functions to obtain the text segments + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data); + + // Progress callback + typedef void (*whisper_progress_callback)(struct whisper_context * ctx, struct whisper_state * state, int progress, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, struct whisper_state * state, void * user_data); + + // Logits filter callback + // Can be used to modify the logits before sampling + // If not NULL, called after applying temperature to logits + typedef void (*whisper_logits_filter_callback)( + struct whisper_context * ctx, + struct whisper_state * state, + const whisper_token_data * tokens, + int n_tokens, + float * logits, + void * user_data); + + // Parameters for the whisper_full() function + // If you change the order or add new parameters, make sure to update the default values in whisper.cpp: + // whisper_full_default_params() + struct whisper_full_params { + enum whisper_sampling_strategy strategy; + + int n_threads; + int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool translate; + bool no_context; // do not use past transcription (if any) as initial prompt for the decoder + bool no_timestamps; // do not generate timestamps + bool single_segment; // force single segment output (useful for streaming) + bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
+        bool print_progress;   // print progress information
+        bool print_realtime;   // print results from within whisper.cpp (avoid it, use callback instead)
+        bool print_timestamps; // print timestamps for each text segment when printing realtime
+
+        // [EXPERIMENTAL] token-level timestamps
+        bool  token_timestamps; // enable token-level timestamps
+        float thold_pt;         // timestamp token probability threshold (~0.01)
+        float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
+        int   max_len;          // max segment length in characters
+        bool  split_on_word;    // split on word rather than on token (when used with max_len)
+        int   max_tokens;       // max tokens per segment (0 = no limit)
+
+        // [EXPERIMENTAL] speed-up techniques
+        // note: these can significantly reduce the quality of the output
+        bool speed_up;   // speed-up the audio by 2x using Phase Vocoder
+        bool debug_mode; // enable debug mode (e.g. dump the log_mel spectrogram)
+        int  audio_ctx;  // overwrite the audio context size (0 = use default)
+
+        // [EXPERIMENTAL] [TDRZ] tinydiarize
+        bool tdrz_enable; // enable tinydiarize speaker turn detection
+
+        // A regular expression that matches tokens to suppress
+        const char * suppress_regex;
+
+        // tokens to provide to the whisper decoder as initial prompt
+        // these are prepended to any existing text context from a previous call
+        // use whisper_tokenize() to convert text to tokens
+        // maximum of whisper_n_text_ctx()/2 tokens are used (typically 224)
+        const char * initial_prompt;
+        const whisper_token * prompt_tokens;
+        int prompt_n_tokens;
+
+        // for auto-detection, set to nullptr, "" or "auto"
+        const char * language;
+        bool detect_language;
+
+        // common decoding parameters:
+        bool suppress_blank;             // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
+        bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
+
+        float temperature;    // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
+        float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
+        float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
+
+        // fallback parameters
+        // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
+        float temperature_inc;
+        float entropy_thold;   // similar to OpenAI's "compression_ratio_threshold"
+        float logprob_thold;
+        float no_speech_thold; // TODO: not implemented
+
+        struct {
+            int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
+        } greedy;
+
+        struct {
+            int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
+
+            float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
+        } beam_search;
+
+        // called for every newly generated text segment
+        whisper_new_segment_callback new_segment_callback;
+        void * new_segment_callback_user_data;
+
+        // called on each progress update
+        whisper_progress_callback progress_callback;
+        void * progress_callback_user_data;
+
+        // called each time before the encoder starts
+        whisper_encoder_begin_callback encoder_begin_callback;
+        void * encoder_begin_callback_user_data;
+
+        // called each time before ggml computation starts
+        ggml_abort_callback abort_callback;
+        void * abort_callback_user_data;
+
+        // called by each decoder to filter obtained logits
+        whisper_logits_filter_callback logits_filter_callback;
+        void * logits_filter_callback_user_data;
+
+        const whisper_grammar_element ** grammar_rules;
+        size_t n_grammar_rules;
+        size_t i_start_rule;
+        float  grammar_penalty;
+    };
+
+    // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see whisper_free_context_params & whisper_free_params()
+    WHISPER_API struct whisper_context_params * whisper_context_default_params_by_ref();
+    WHISPER_API struct whisper_context_params whisper_context_default_params(void);
+    WHISPER_API struct whisper_full_params * whisper_full_default_params_by_ref(enum whisper_sampling_strategy strategy);
+    WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
+
+    // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
+    // Not thread safe for same context
+    // Uses the specified decoding strategy to obtain the text.
+    WHISPER_API int whisper_full(
+            struct whisper_context * ctx,
+            struct whisper_full_params params,
+            const float * samples,
+            int n_samples);
+
+    WHISPER_API int whisper_full_with_state(
+            struct whisper_context * ctx,
+            struct whisper_state * state,
+            struct whisper_full_params params,
+            const float * samples,
+            int n_samples);
+
+    // Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
+    // Result is stored in the default state of the context
+    // Not thread safe if executed in parallel on the same context.
+    // It seems this approach can offer some speedup in some cases.
+    // However, the transcription accuracy can be worse at the beginning and end of each chunk.
+    WHISPER_API int whisper_full_parallel(
+            struct whisper_context * ctx,
+            struct whisper_full_params params,
+            const float * samples,
+            int n_samples,
+            int n_processors);
+
+    // Number of generated text segments
+    // A segment can be a few words, a sentence, or even a paragraph.
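+    // Illustration (editor's sketch, not upstream text) of the typical flow,
+    // assuming a loaded `ctx` and 16 kHz mono float samples in `pcmf32`:
+    //
+    //     whisper_full_params p = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+    //     if (whisper_full(ctx, p, pcmf32.data(), (int) pcmf32.size()) == 0) {
+    //         for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
+    //             printf("%s", whisper_full_get_segment_text(ctx, i));
+    //         }
+    //     }
+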
+    WHISPER_API int whisper_full_n_segments           (struct whisper_context * ctx);
+    WHISPER_API int whisper_full_n_segments_from_state(struct whisper_state * state);
+
+    // Language id associated with the context's default state
+    WHISPER_API int whisper_full_lang_id(struct whisper_context * ctx);
+
+    // Language id associated with the provided state
+    WHISPER_API int whisper_full_lang_id_from_state(struct whisper_state * state);
+
+    // Get the start and end time of the specified segment
+    WHISPER_API int64_t whisper_full_get_segment_t0           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t0_from_state(struct whisper_state * state, int i_segment);
+
+    WHISPER_API int64_t whisper_full_get_segment_t1           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int64_t whisper_full_get_segment_t1_from_state(struct whisper_state * state, int i_segment);
+
+    // Get whether the next segment is predicted as a speaker turn
+    WHISPER_API bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment);
+    WHISPER_API bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment);
+
+    // Get the text of the specified segment
+    WHISPER_API const char * whisper_full_get_segment_text           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API const char * whisper_full_get_segment_text_from_state(struct whisper_state * state, int i_segment);
+
+    // Get number of tokens in the specified segment
+    WHISPER_API int whisper_full_n_tokens           (struct whisper_context * ctx, int i_segment);
+    WHISPER_API int whisper_full_n_tokens_from_state(struct whisper_state * state, int i_segment);
+
+    // Get the token text of the specified token in the specified segment
+    WHISPER_API const char * whisper_full_get_token_text           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API const char * whisper_full_get_token_text_from_state(struct whisper_context * ctx, struct whisper_state * state, int i_segment, int i_token);
+
+    WHISPER_API whisper_token whisper_full_get_token_id           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token whisper_full_get_token_id_from_state(struct whisper_state * state, int i_segment, int i_token);
+
+    // Get token data for the specified token in the specified segment
+    // This contains probabilities, timestamps, etc.
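+    // Illustration (editor's sketch, not upstream text): token-level access for
+    // segment `i`, assuming whisper_token_data carries a probability field `p`
+    // as in upstream whisper.h:
+    //
+    //     for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+    //         whisper_token_data td = whisper_full_get_token_data(ctx, i, j);
+    //         printf("%s (p=%.3f)\n", whisper_full_get_token_text(ctx, i, j), td.p);
+    //     }
+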
+    WHISPER_API whisper_token_data whisper_full_get_token_data           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API whisper_token_data whisper_full_get_token_data_from_state(struct whisper_state * state, int i_segment, int i_token);
+
+    // Get the probability of the specified token in the specified segment
+    WHISPER_API float whisper_full_get_token_p           (struct whisper_context * ctx, int i_segment, int i_token);
+    WHISPER_API float whisper_full_get_token_p_from_state(struct whisper_state * state, int i_segment, int i_token);
+
+    ////////////////////////////////////////////////////////////////////////////
+
+    // Temporary helpers needed for exposing ggml interface
+
+    WHISPER_API int          whisper_bench_memcpy          (int n_threads);
+    WHISPER_API const char * whisper_bench_memcpy_str      (int n_threads);
+    WHISPER_API int          whisper_bench_ggml_mul_mat    (int n_threads);
+    WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads);
+
+    // Control logging output; default behavior is to print to stderr
+
+    WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+
+//don't want to modify common.h, throw functions here
+// convert timestamp to string, 6000 -> 01:00.000
+std::string to_timestamp(int64_t t, bool comma = false);
+// given a timestamp get the sample
+int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate);
+// check if file exists using ifstream
+bool is_file_exist(const char *fileName);
\ No newline at end of file
diff --git a/otherarch/whispercpp/whisper_adapter.cpp b/otherarch/whispercpp/whisper_adapter.cpp
new file mode 100644
index 000000000..944405da0
--- /dev/null
+++ b/otherarch/whispercpp/whisper_adapter.cpp
@@ -0,0 +1,243 @@
+#include "model_adapter.h"
+#include "otherarch/utils.h"
+
+#include "whisper.cpp"
+
+#define DR_WAV_IMPLEMENTATION
+#include "dr_wav.h"
+
+#include <cmath>
+#include <cstdint>
+#include <cstdio>
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <string>
+#include <thread>
+#include <vector>
+
+#define COMMON_SAMPLE_RATE 16000
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
+
+static int whisperdebugmode = 0;
+static whisper_context * whisper_ctx = nullptr;
+static std::string whisper_output_text = "";
+
+
+static bool is_wav_buffer(const std::string buf) {
+    // RIFF ref: https://en.wikipedia.org/wiki/Resource_Interchange_File_Format
+    // WAV ref: https://www.mmsp.ece.mcgill.ca/Documents/AudioFormats/WAVE/WAVE.html
+    if (buf.size() < 12 || buf.substr(0, 4) != "RIFF" || buf.substr(8, 4) != "WAVE") {
+        return false;
+    }
+    uint32_t chunk_size = *reinterpret_cast<const uint32_t *>(buf.data() + 4);
+    if (chunk_size + 8 != buf.size()) {
+        return false;
+    }
+    return true;
+}
+
+static bool read_wav(const std::string & b64data, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo)
+{
+    drwav wav;
+    std::vector<uint8_t> wav_data = kcpp_base64_decode(b64data);
+
+    if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
+        printf("error: failed to open WAV data from memory\n");
+        return false;
+    }
+
+    if (wav.channels != 1 && wav.channels != 2) {
+        printf("WAV file must be mono or stereo\n");
+        drwav_uninit(&wav);
+        return false;
+    }
+
+    if (wav.sampleRate != COMMON_SAMPLE_RATE) {
+        printf("WAV file must be %i kHz\n", COMMON_SAMPLE_RATE/1000);
+        drwav_uninit(&wav);
+        return false;
+    }
+
+    if (wav.bitsPerSample != 16) {
+        printf("WAV file must be 16-bit\n");
+        drwav_uninit(&wav);
+        return false;
+    }
+
+    const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
+
+    std::vector<int16_t> pcm16;
+    pcm16.resize(n*wav.channels);
+    drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
+    drwav_uninit(&wav);
+
+    // convert to mono, float
+    pcmf32.resize(n);
+    if (wav.channels == 1) {
+        for (uint64_t i = 0; i < n; i++) {
+            pcmf32[i] = float(pcm16[i])/32768.0f;
+        }
+    } else {
+        for (uint64_t i = 0; i < n; i++) {
+            pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
+        }
+    }
+
+    return true;
+}
+
+static std::string output_txt(struct whisper_context * ctx, std::vector<std::vector<float>> pcmf32s) {
+
+    std::string outtxt = "";
+    const int n_segments = whisper_full_n_segments(ctx);
+    for (int i = 0; i < n_segments; ++i) {
+        const char * text = whisper_full_get_segment_text(ctx, i);
+        outtxt += text;
+    }
+    return outtxt;
+}
+
+void cb_log_disable(enum ggml_log_level , const char * , void * ) { }
+
+static std::string whisperplatformenv, whisperdeviceenv, whispervulkandeviceenv;
+bool whispertype_load_model(const whisper_load_model_inputs inputs)
+{
+    //duplicated from expose.cpp
+    int cl_parseinfo = inputs.clblast_info; //first digit is whether configured, second is platform, third is devices
+    std::string usingclblast = "GGML_OPENCL_CONFIGURED="+std::to_string(cl_parseinfo>0?1:0);
+    putenv((char*)usingclblast.c_str());
+    cl_parseinfo = cl_parseinfo%100; //keep last 2 digits
+    int platform = cl_parseinfo/10;
+    int devices = cl_parseinfo%10;
+    whisperplatformenv = "GGML_OPENCL_PLATFORM="+std::to_string(platform);
+    whisperdeviceenv = "GGML_OPENCL_DEVICE="+std::to_string(devices);
+    putenv((char*)whisperplatformenv.c_str());
+    putenv((char*)whisperdeviceenv.c_str());
+    std::string vulkan_info_raw = inputs.vulkan_info;
+    std::string vulkan_info_str = "";
+    //turn the raw device list (e.g. "01") into a comma separated list (e.g. "0,1")
+    for (size_t i = 0; i < vulkan_info_raw.length(); ++i) {
+        vulkan_info_str += vulkan_info_raw[i];
+        if (i < vulkan_info_raw.length() - 1) {
+            vulkan_info_str += ",";
+        }
+    }
+    if(vulkan_info_str=="")
+    {
+        vulkan_info_str = "0";
+    }
+    whispervulkandeviceenv = "GGML_VK_VISIBLE_DEVICES="+vulkan_info_str;
+    putenv((char*)whispervulkandeviceenv.c_str());
+
+
+    std::string modelfile = inputs.model_filename;
+    printf("\nLoading Whisper Model: %s",modelfile.c_str());
+
+    whisperdebugmode = inputs.debugmode;
+    if (whisperdebugmode!=1) {
+        whisper_log_set(cb_log_disable, NULL);
+    }
+
+    // whisper init
+    struct whisper_context_params cparams = whisper_context_default_params();
+    cparams.use_gpu = true;
+    cparams.flash_attn = false;
+
+    whisper_ctx = whisper_init_from_file_with_params(modelfile.c_str(), cparams);
+
+    if (whisper_ctx == nullptr) {
+        printf("\nWhisper Load Error: Failed to initialize whisper context!\n");
+        return false;
+    }
+
+    printf("\nWhisper Load Complete.\n");
+
+    return true;
+}
+
+whisper_generation_outputs whispertype_generate(const whisper_generation_inputs inputs)
+{
+    whisper_generation_outputs output;
+
+    if(whisper_ctx==nullptr)
+    {
+        printf("\nWarning: KCPP whisper not initialized!\n");
+        output.text = "";
+        output.status = 0;
+        return output;
+    }
+
+    if(!inputs.quiet)
+    {
+        printf("\nWhisper Transcribe Generating...");
+    }
+
+    const std::string b64data = std::string(inputs.audio_data);
+    const std::string initprompt = std::string(inputs.prompt);
+
+    std::vector<float> pcmf32;               // mono-channel F32 PCM
+    std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
+
+    if (!::read_wav(b64data, pcmf32, pcmf32s, false)) {
+        printf("\nWhisper: Failed to read input wav data!\n");
+        output.text = "";
+        output.status = 0;
+        return output;
+    }
+
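+    // At this point pcmf32 holds mono 16 kHz float samples decoded from the
+    // base64 WAV payload. Caller-side sketch (editor's illustration, with
+    // hypothetical field values) of the inputs this function expects:
+    //
+    //     whisper_generation_inputs in;
+    //     in.audio_data = wav_b64.c_str(); // base64 of a 16-bit PCM WAV at 16 kHz
+    //     in.prompt = "";                  // optional initial prompt
+    //     in.quiet = true;
+    //     whisper_generation_outputs out = whispertype_generate(in);
+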
+    // run the inference
+    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+    wparams.strategy = WHISPER_SAMPLING_GREEDY;
+    wparams.print_realtime = false;
+    wparams.print_progress = false;
+    wparams.print_timestamps = false;
+    wparams.print_special = false;
+    wparams.translate = false;
+    wparams.language = "en";
+    wparams.detect_language = false;
+    wparams.n_threads = 4;
+    //n_max_text_ctx is left at the default set by whisper_full_default_params()
+    wparams.offset_ms = 0;
+    wparams.duration_ms = 0;
+    wparams.token_timestamps = false;
+    wparams.thold_pt = 0.01f;
+    wparams.max_len = 100;
+    wparams.split_on_word = false;
+    wparams.audio_ctx = 0;
+    wparams.speed_up = false;
+    wparams.debug_mode = false;
+    wparams.tdrz_enable = false;
+    wparams.suppress_regex = nullptr;
+    wparams.initial_prompt = initprompt.c_str();
+    wparams.greedy.best_of = -1;
+    wparams.beam_search.beam_size = -1;
+    wparams.temperature_inc = 0.2f;
+    wparams.temperature = 0.0f;
+    wparams.entropy_thold = 2.40f;
+    wparams.logprob_thold = -1.00f;
+    wparams.no_timestamps = true;
+
+    if (whisper_full_parallel(whisper_ctx, wparams, pcmf32.data(), pcmf32.size(), 1) != 0) {
+        printf("\nWhisper: Failed to process audio!\n");
+        output.text = "";
+        output.status = 0;
+        return output;
+    }
+
+    if (!inputs.quiet && whisperdebugmode==1) {
+        whisper_print_timings(whisper_ctx);
+    }
+
+    // output text transcription
+    whisper_output_text = output_txt(whisper_ctx, pcmf32s);
+    if(!inputs.quiet)
+    {
+        printf("\nWhisper Transcribe Output: %s",whisper_output_text.c_str());
+    }
+    output.text = whisper_output_text.c_str(); //whisper_output_text is static, so the pointer stays valid after return
+    output.status = 1;
+    return output;
+}
\ No newline at end of file
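
For reference, a minimal standalone sketch of the whisper.cpp call sequence that whispertype_load_model() and whispertype_generate() wrap. The model path is a placeholder, error handling is elided, and pcmf32 is assumed to already hold 16 kHz mono float samples:

    #include "whisper.h"
    #include <cstdio>
    #include <vector>

    int main() {
        // load the model with default context parameters
        struct whisper_context_params cparams = whisper_context_default_params();
        struct whisper_context * ctx = whisper_init_from_file_with_params("ggml-base.en.bin", cparams); // placeholder path
        if (ctx == nullptr) return 1;

        std::vector<float> pcmf32; // fill with 16 kHz mono float samples

        // greedy decoding, text only
        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        wparams.no_timestamps = true;
        if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) == 0) {
            // collect the transcription segment by segment
            for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
                printf("%s", whisper_full_get_segment_text(ctx, i));
            }
        }

        whisper_free(ctx);
        return 0;
    }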