improved tool calls and whisper

This commit is contained in:
Concedo 2024-12-06 14:34:31 +08:00
parent 836c06d91a
commit e9d2332dd8
4 changed files with 243 additions and 305 deletions

View file

@ -187,6 +187,7 @@ struct whisper_generation_inputs
{ {
const char * prompt = nullptr; const char * prompt = nullptr;
const char * audio_data = nullptr; const char * audio_data = nullptr;
const bool suppress_non_speech = false;
const bool quiet = false; const bool quiet = false;
}; };
struct whisper_generation_outputs struct whisper_generation_outputs

File diff suppressed because one or more lines are too long

View file

@ -264,6 +264,7 @@ class whisper_load_model_inputs(ctypes.Structure):
class whisper_generation_inputs(ctypes.Structure): class whisper_generation_inputs(ctypes.Structure):
_fields_ = [("prompt", ctypes.c_char_p), _fields_ = [("prompt", ctypes.c_char_p),
("audio_data", ctypes.c_char_p), ("audio_data", ctypes.c_char_p),
("suppress_non_speech", ctypes.c_bool),
("quiet", ctypes.c_bool)] ("quiet", ctypes.c_bool)]
class whisper_generation_outputs(ctypes.Structure): class whisper_generation_outputs(ctypes.Structure):
@ -1236,6 +1237,7 @@ def whisper_generate(genparams):
inputs.prompt = prompt.encode("UTF-8") inputs.prompt = prompt.encode("UTF-8")
inputs.audio_data = audio_data.encode("UTF-8") inputs.audio_data = audio_data.encode("UTF-8")
inputs.quiet = is_quiet inputs.quiet = is_quiet
inputs.suppress_non_speech = genparams.get("suppress_non_speech", False)
ret = handle.whisper_generate(inputs) ret = handle.whisper_generate(inputs)
outstr = "" outstr = ""
if ret.status==1: if ret.status==1:
@ -1392,9 +1394,11 @@ def transform_genparams(genparams, api_format):
messages_string += tools_message_start messages_string += tools_message_start
# content can be a string or an array of objects # content can be a string or an array of objects
curr_content = message['content'] curr_content = message.get("content",None)
if isinstance(curr_content, str): if not curr_content:
messages_string += curr_content pass # do nothing
elif isinstance(curr_content, str):
messages_string += curr_content
elif isinstance(curr_content, list): #is an array elif isinstance(curr_content, list): #is an array
for item in curr_content: for item in curr_content:
if item['type']=="text": if item['type']=="text":

View file

@ -251,6 +251,7 @@ whisper_generation_outputs whispertype_generate(const whisper_generation_inputs
wparams.debug_mode = false; wparams.debug_mode = false;
wparams.tdrz_enable = false; wparams.tdrz_enable = false;
wparams.suppress_regex = nullptr; wparams.suppress_regex = nullptr;
wparams.suppress_non_speech_tokens = inputs.suppress_non_speech;
wparams.initial_prompt = initprompt.c_str(); wparams.initial_prompt = initprompt.c_str();
wparams.greedy.best_of = -1; wparams.greedy.best_of = -1;
wparams.beam_search.beam_size = -1; wparams.beam_search.beam_size = -1;