From 5c9714cf4001de4a3b54b0d3d4f80338ae1e21a5 Mon Sep 17 00:00:00 2001 From: Concedo <39025047+LostRuins@users.noreply.github.com> Date: Sun, 19 Jan 2025 16:57:41 +0800 Subject: [PATCH] improve whisper to work on 8 bit and 32bit wav too, also support form data for language --- koboldcpp.py | 38 +++++++++++++++++------- otherarch/whispercpp/whisper_adapter.cpp | 22 ++++++++++++-- 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/koboldcpp.py b/koboldcpp.py index 887f4e7f1..3b6d7b8bd 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -1901,7 +1901,8 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): super().log_message(format, *args) pass - def extract_b64string_from_file_upload(self, body): + def extract_transcribe_from_file_upload(self, body): + result = {"file": None, "prompt": None, "language": None} try: if 'content-type' in self.headers and self.headers['content-type']: boundary = self.headers['content-type'].split("=")[1].encode() @@ -1914,15 +1915,27 @@ class ServerRequestHandler(http.server.SimpleHTTPRequestHandler): file_content_start = fpart.find(b'\r\n\r\n') + 4 # Position after headers file_content_end = fpart.rfind(b'\r\n') # Ending boundary if file_content_start != -1 and file_content_end != -1: - file_data = fpart[file_content_start:file_content_end] - file_data_base64 = base64.b64encode(file_data).decode('utf-8',"ignore") - base64_string = f"data:audio/wav;base64,{file_data_base64}" - return base64_string - print("Uploaded file not found.") - return None + if "file" in result and result["file"] is None: + file_data = fpart[file_content_start:file_content_end] + file_data_base64 = base64.b64encode(file_data).decode('utf-8',"ignore") + base64_string = f"data:audio/wav;base64,{file_data_base64}" + result["file"] = base64_string + + # Check for fields + detected_prompt_field = re.findall(r'Content-Disposition.*name="prompt"\r\n\r\n(.*)\r\n', fpart.decode('utf-8', errors='ignore')) + if detected_prompt_field and len(detected_prompt_field)>0: + result["prompt"] = detected_prompt_field[0].strip() # Extract and strip whitespace + + detected_lang_field = re.findall(r'Content-Disposition.*name="language"\r\n\r\n(.*)\r\n', fpart.decode('utf-8', errors='ignore')) + if detected_lang_field and len(detected_lang_field)>0: + result["language"] = detected_lang_field[0].strip() # Extract and strip whitespace + + if not ("file" in result and result["file"]): + print("Uploaded file not found.") + return result except Exception as e: print(f"File Upload Process Error: {e}") - return None + return result async def generate_text(self, genparams, api_format, stream_flag): global friendlymodelname, chatcompl_adapter, currfinishreason @@ -2742,9 +2755,14 @@ Enter Prompt:
except Exception: genparams = None if is_transcribe: #fallback handling of file uploads - b64wav = self.extract_b64string_from_file_upload(body) - if b64wav: + formdata = self.extract_transcribe_from_file_upload(body) + if "file" in formdata and formdata["file"]: + b64wav = formdata["file"] genparams = {"audio_data":b64wav} + if "prompt" in formdata and formdata["prompt"]: + genparams["prompt"] = formdata["prompt"] + if "language" in formdata and formdata["language"]: + genparams["language"] = formdata["language"] if not genparams: utfprint("Body Err: " + str(body)) diff --git a/otherarch/whispercpp/whisper_adapter.cpp b/otherarch/whispercpp/whisper_adapter.cpp index b666ee9f4..5b82f97be 100644 --- a/otherarch/whispercpp/whisper_adapter.cpp +++ b/otherarch/whispercpp/whisper_adapter.cpp @@ -57,8 +57,8 @@ static bool read_wav(const std::string & b64data, std::vector& pcmf32, st return false; } - if (wav.bitsPerSample != 16) { - printf("WAV file must be 16-bit\n"); + if (wav.bitsPerSample != 8 && wav.bitsPerSample != 16 && wav.bitsPerSample != 32) { + printf("WAV file must be 8-bit, 16-bit or 32-bit. Detected: %d\n",wav.bitsPerSample); drwav_uninit(&wav); return false; } @@ -67,7 +67,23 @@ static bool read_wav(const std::string & b64data, std::vector& pcmf32, st std::vector pcm16; pcm16.resize(n*wav.channels); - drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + + if (wav.bitsPerSample == 8) { + // Handle 8-bit PCM and convert to 16-bit + std::vector pcm8(n * wav.channels); + drwav_read_pcm_frames(&wav, n, pcm8.data()); + drwav_u8_to_s16(pcm16.data(), pcm8.data(), n * wav.channels); + } else if (wav.bitsPerSample == 16) { + // Handle 16-bit PCM directly + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + } else if (wav.bitsPerSample == 32) { + // Handle 32-bit PCM and convert to 16-bit + std::vector pcm32(n * wav.channels); + drwav_read_pcm_frames_s32(&wav, n, pcm32.data()); + for (uint64_t i = 0; i < n * wav.channels; ++i) { + pcm16[i] = static_cast(pcm32[i] >> 16); // Scale down by shifting + } + } drwav_uninit(&wav); std::vector raw_pcm;