mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-02 10:41:04 +00:00
Force Claude 3 models to output JSON object and parse it more reliably (#293)
Co-authored-by: otmane <otmanebenazzou.pro@gmail.com>
This commit is contained in:
parent
49baf471ab
commit
cf01e81ba2
4 changed files with 62 additions and 18 deletions
|
@ -81,7 +81,7 @@ class LLMAPIHandlerFactory:
|
|||
data=screenshot,
|
||||
)
|
||||
|
||||
messages = await llm_messages_builder(prompt, screenshots)
|
||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
|
@ -115,7 +115,7 @@ class LLMAPIHandlerFactory:
|
|||
organization_id=step.organization_id,
|
||||
incremental_cost=llm_cost,
|
||||
)
|
||||
parsed_response = parse_api_response(response)
|
||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
|
@ -159,7 +159,7 @@ class LLMAPIHandlerFactory:
|
|||
if not llm_config.supports_vision:
|
||||
screenshots = None
|
||||
|
||||
messages = await llm_messages_builder(prompt, screenshots)
|
||||
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
|
@ -199,7 +199,7 @@ class LLMAPIHandlerFactory:
|
|||
organization_id=step.organization_id,
|
||||
incremental_cost=llm_cost,
|
||||
)
|
||||
parsed_response = parse_api_response(response)
|
||||
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||
if step:
|
||||
await app.ARTIFACT_MANAGER.create_artifact(
|
||||
step=step,
|
||||
|
|
|
@ -55,21 +55,38 @@ if not any(
|
|||
|
||||
|
||||
if SettingsManager.get_settings().ENABLE_OPENAI:
|
||||
LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], True))
|
||||
LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], True))
|
||||
LLMConfigRegistry.register_config(
|
||||
"OPENAI_GPT4_TURBO",
|
||||
LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=False, add_assistant_prefix=False),
|
||||
)
|
||||
LLMConfigRegistry.register_config(
|
||||
"OPENAI_GPT4V", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False)
|
||||
)
|
||||
|
||||
if SettingsManager.get_settings().ENABLE_ANTHROPIC:
|
||||
LLMConfigRegistry.register_config(
|
||||
"ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True)
|
||||
"ANTHROPIC_CLAUDE3",
|
||||
LLMConfig(
|
||||
"anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
|
||||
),
|
||||
)
|
||||
LLMConfigRegistry.register_config(
|
||||
"ANTHROPIC_CLAUDE3_OPUS", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
|
||||
"ANTHROPIC_CLAUDE3_OPUS",
|
||||
LLMConfig(
|
||||
"anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
|
||||
),
|
||||
)
|
||||
LLMConfigRegistry.register_config(
|
||||
"ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True)
|
||||
"ANTHROPIC_CLAUDE3_SONNET",
|
||||
LLMConfig(
|
||||
"anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
|
||||
),
|
||||
)
|
||||
LLMConfigRegistry.register_config(
|
||||
"ANTHROPIC_CLAUDE3_HAIKU", LLMConfig("anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], True)
|
||||
"ANTHROPIC_CLAUDE3_HAIKU",
|
||||
LLMConfig(
|
||||
"anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
|
||||
),
|
||||
)
|
||||
|
||||
if SettingsManager.get_settings().ENABLE_BEDROCK:
|
||||
|
@ -79,7 +96,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
|
|||
LLMConfig(
|
||||
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
|
||||
["AWS_REGION"],
|
||||
True,
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=True,
|
||||
),
|
||||
)
|
||||
LLMConfigRegistry.register_config(
|
||||
|
@ -87,7 +105,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
|
|||
LLMConfig(
|
||||
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
["AWS_REGION"],
|
||||
True,
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=True,
|
||||
),
|
||||
)
|
||||
LLMConfigRegistry.register_config(
|
||||
|
@ -95,7 +114,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
|
|||
LLMConfig(
|
||||
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||
["AWS_REGION"],
|
||||
True,
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -105,6 +125,7 @@ if SettingsManager.get_settings().ENABLE_AZURE:
|
|||
LLMConfig(
|
||||
f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
|
||||
["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
|
||||
True,
|
||||
supports_vision=True,
|
||||
add_assistant_prefix=False,
|
||||
),
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@ class LLMConfig:
|
|||
model_name: str
|
||||
required_env_vars: list[str]
|
||||
supports_vision: bool
|
||||
add_assistant_prefix: bool
|
||||
|
||||
def get_missing_env_vars(self) -> list[str]:
|
||||
missing_env_vars = []
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import base64
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
import commentjson
|
||||
|
@ -10,6 +11,7 @@ from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidL
|
|||
async def llm_messages_builder(
|
||||
prompt: str,
|
||||
screenshots: list[bytes] | None = None,
|
||||
add_assistant_prefix: bool = False,
|
||||
) -> list[dict[str, Any]]:
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
|
@ -29,17 +31,37 @@ async def llm_messages_builder(
|
|||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Anthropic models seem to struggle to always output a valid JSON object, so we prefill the response with "{" to force it:
|
||||
if add_assistant_prefix:
|
||||
return [{"role": "user", "content": messages}, {"role": "assistant", "content": "{"}]
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
|
||||
def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, str]:
    """Parse the first choice of an LLM response into a JSON object.

    Args:
        response: Completion response; only ``choices[0].message.content`` is read.
        add_assistant_prefix: Set when the request prefilled the assistant turn
            with "{". The model's reply then starts after that brace, so it is
            put back before parsing.

    Returns:
        The value decoded by ``commentjson.loads`` from the cleaned content.

    Raises:
        InvalidLLMResponseFormat: On any failure in the steps below. The
            EmptyLLMResponseError raised for empty content is also caught by
            the broad handler, so it reaches callers only as ``__cause__``.
    """
    try:
        content = response.choices[0].message.content
        # Since we prefilled the Anthropic response with "{", add it back so the
        # content is a complete JSON object again.
        if add_assistant_prefix:
            content = "{" + content
        # Unwrap a ```json fenced block first (the fences must still be present
        # for this to match), then trim any chatter outside the outermost braces.
        content = try_to_extract_json_from_markdown_format(content)
        content = replace_useless_text_around_json(content)
        if not content:
            raise EmptyLLMResponseError(str(response))
        return commentjson.loads(content)
    except Exception as e:
        raise InvalidLLMResponseFormat(str(response)) from e
|
||||
|
||||
|
||||
def replace_useless_text_around_json(input_string: str) -> str:
    """Return the substring from the first "{" to the last "}" inclusive.

    Text before the first opening brace and after the last closing brace is
    dropped. No JSON validation is performed; the caller is expected to parse
    the result.

    Args:
        input_string: Raw text that may contain a JSON object surrounded by
            other text.

    Returns:
        The brace-delimited span, or an empty string when there is no "{",
        no "}", or the last "}" comes before the first "{".
    """
    start = input_string.find("{")
    end = input_string.rfind("}")
    # str.find/rfind return -1 when nothing is found. Used directly as a slice
    # index, -1 means "last character", which could leak a stray "}" back to
    # the caller, so reject the not-found and out-of-order cases explicitly.
    if start == -1 or end < start:
        return ""
    return input_string[start : end + 1]
|
||||
|
||||
|
||||
def try_to_extract_json_from_markdown_format(text: str) -> str:
    """Return the body of the first ```json fenced block in ``text``.

    Whitespace directly inside the fences is not part of the result. When no
    ```json fence is present, ``text`` is returned unchanged.
    """
    # DOTALL lets "." cross newlines so multi-line JSON bodies are captured.
    fenced = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
    return text if fenced is None else fenced.group(1)
|
||||
|
|
Loading…
Add table
Reference in a new issue