mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2025-09-02 10:41:04 +00:00
Force Claude 3 models to output JSON object and parse it more reliably (#293)
Co-authored-by: otmane <otmanebenazzou.pro@gmail.com>
This commit is contained in:
parent
49baf471ab
commit
cf01e81ba2
4 changed files with 62 additions and 18 deletions
|
@ -81,7 +81,7 @@ class LLMAPIHandlerFactory:
|
||||||
data=screenshot,
|
data=screenshot,
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = await llm_messages_builder(prompt, screenshots)
|
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||||
if step:
|
if step:
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
step=step,
|
step=step,
|
||||||
|
@ -115,7 +115,7 @@ class LLMAPIHandlerFactory:
|
||||||
organization_id=step.organization_id,
|
organization_id=step.organization_id,
|
||||||
incremental_cost=llm_cost,
|
incremental_cost=llm_cost,
|
||||||
)
|
)
|
||||||
parsed_response = parse_api_response(response)
|
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||||
if step:
|
if step:
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
step=step,
|
step=step,
|
||||||
|
@ -159,7 +159,7 @@ class LLMAPIHandlerFactory:
|
||||||
if not llm_config.supports_vision:
|
if not llm_config.supports_vision:
|
||||||
screenshots = None
|
screenshots = None
|
||||||
|
|
||||||
messages = await llm_messages_builder(prompt, screenshots)
|
messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
|
||||||
if step:
|
if step:
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
step=step,
|
step=step,
|
||||||
|
@ -199,7 +199,7 @@ class LLMAPIHandlerFactory:
|
||||||
organization_id=step.organization_id,
|
organization_id=step.organization_id,
|
||||||
incremental_cost=llm_cost,
|
incremental_cost=llm_cost,
|
||||||
)
|
)
|
||||||
parsed_response = parse_api_response(response)
|
parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
|
||||||
if step:
|
if step:
|
||||||
await app.ARTIFACT_MANAGER.create_artifact(
|
await app.ARTIFACT_MANAGER.create_artifact(
|
||||||
step=step,
|
step=step,
|
||||||
|
|
|
@ -55,21 +55,38 @@ if not any(
|
||||||
|
|
||||||
|
|
||||||
if SettingsManager.get_settings().ENABLE_OPENAI:
|
if SettingsManager.get_settings().ENABLE_OPENAI:
|
||||||
LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], True))
|
LLMConfigRegistry.register_config(
|
||||||
LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], True))
|
"OPENAI_GPT4_TURBO",
|
||||||
|
LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=False, add_assistant_prefix=False),
|
||||||
|
)
|
||||||
|
LLMConfigRegistry.register_config(
|
||||||
|
"OPENAI_GPT4V", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False)
|
||||||
|
)
|
||||||
|
|
||||||
if SettingsManager.get_settings().ENABLE_ANTHROPIC:
|
if SettingsManager.get_settings().ENABLE_ANTHROPIC:
|
||||||
LLMConfigRegistry.register_config(
|
LLMConfigRegistry.register_config(
|
||||||
"ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True)
|
"ANTHROPIC_CLAUDE3",
|
||||||
|
LLMConfig(
|
||||||
|
"anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
LLMConfigRegistry.register_config(
|
LLMConfigRegistry.register_config(
|
||||||
"ANTHROPIC_CLAUDE3_OPUS", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True)
|
"ANTHROPIC_CLAUDE3_OPUS",
|
||||||
|
LLMConfig(
|
||||||
|
"anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
LLMConfigRegistry.register_config(
|
LLMConfigRegistry.register_config(
|
||||||
"ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True)
|
"ANTHROPIC_CLAUDE3_SONNET",
|
||||||
|
LLMConfig(
|
||||||
|
"anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
LLMConfigRegistry.register_config(
|
LLMConfigRegistry.register_config(
|
||||||
"ANTHROPIC_CLAUDE3_HAIKU", LLMConfig("anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], True)
|
"ANTHROPIC_CLAUDE3_HAIKU",
|
||||||
|
LLMConfig(
|
||||||
|
"anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if SettingsManager.get_settings().ENABLE_BEDROCK:
|
if SettingsManager.get_settings().ENABLE_BEDROCK:
|
||||||
|
@ -79,7 +96,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
|
||||||
LLMConfig(
|
LLMConfig(
|
||||||
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
|
"bedrock/anthropic.claude-3-opus-20240229-v1:0",
|
||||||
["AWS_REGION"],
|
["AWS_REGION"],
|
||||||
True,
|
supports_vision=True,
|
||||||
|
add_assistant_prefix=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
LLMConfigRegistry.register_config(
|
LLMConfigRegistry.register_config(
|
||||||
|
@ -87,7 +105,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
|
||||||
LLMConfig(
|
LLMConfig(
|
||||||
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
|
||||||
["AWS_REGION"],
|
["AWS_REGION"],
|
||||||
True,
|
supports_vision=True,
|
||||||
|
add_assistant_prefix=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
LLMConfigRegistry.register_config(
|
LLMConfigRegistry.register_config(
|
||||||
|
@ -95,7 +114,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
|
||||||
LLMConfig(
|
LLMConfig(
|
||||||
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
|
||||||
["AWS_REGION"],
|
["AWS_REGION"],
|
||||||
True,
|
supports_vision=True,
|
||||||
|
add_assistant_prefix=True,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -105,6 +125,7 @@ if SettingsManager.get_settings().ENABLE_AZURE:
|
||||||
LLMConfig(
|
LLMConfig(
|
||||||
f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
|
f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
|
||||||
["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
|
["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
|
||||||
True,
|
supports_vision=True,
|
||||||
|
add_assistant_prefix=False,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,6 +10,7 @@ class LLMConfig:
|
||||||
model_name: str
|
model_name: str
|
||||||
required_env_vars: list[str]
|
required_env_vars: list[str]
|
||||||
supports_vision: bool
|
supports_vision: bool
|
||||||
|
add_assistant_prefix: bool
|
||||||
|
|
||||||
def get_missing_env_vars(self) -> list[str]:
|
def get_missing_env_vars(self) -> list[str]:
|
||||||
missing_env_vars = []
|
missing_env_vars = []
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import base64
|
import base64
|
||||||
|
import re
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import commentjson
|
import commentjson
|
||||||
|
@ -10,6 +11,7 @@ from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidL
|
||||||
async def llm_messages_builder(
|
async def llm_messages_builder(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
screenshots: list[bytes] | None = None,
|
screenshots: list[bytes] | None = None,
|
||||||
|
add_assistant_prefix: bool = False,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
messages: list[dict[str, Any]] = [
|
messages: list[dict[str, Any]] = [
|
||||||
{
|
{
|
||||||
|
@ -29,17 +31,37 @@ async def llm_messages_builder(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
# Anthropic models seems to struggle to always output a valid json object so we need to prefill the response to force it:
|
||||||
|
if add_assistant_prefix:
|
||||||
|
return [{"role": "user", "content": messages}, {"role": "assistant", "content": "{"}]
|
||||||
return [{"role": "user", "content": messages}]
|
return [{"role": "user", "content": messages}]
|
||||||
|
|
||||||
|
|
||||||
def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]:
|
def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, str]:
|
||||||
try:
|
try:
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
content = content.replace("```json", "")
|
# Since we prefilled Anthropic response with "{" we need to add it back to the response to have a valid json object:
|
||||||
content = content.replace("```", "")
|
if add_assistant_prefix:
|
||||||
|
content = "{" + content
|
||||||
|
content = try_to_extract_json_from_markdown_format(content)
|
||||||
|
content = replace_useless_text_around_json(content)
|
||||||
if not content:
|
if not content:
|
||||||
raise EmptyLLMResponseError(str(response))
|
raise EmptyLLMResponseError(str(response))
|
||||||
return commentjson.loads(content)
|
return commentjson.loads(content)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise InvalidLLMResponseFormat(str(response)) from e
|
raise InvalidLLMResponseFormat(str(response)) from e
|
||||||
|
|
||||||
|
|
||||||
|
def replace_useless_text_around_json(input_string: str) -> str:
|
||||||
|
first_occurrence_of_brace = input_string.find("{")
|
||||||
|
last_occurrence_of_brace = input_string.rfind("}")
|
||||||
|
return input_string[first_occurrence_of_brace : last_occurrence_of_brace + 1]
|
||||||
|
|
||||||
|
|
||||||
|
def try_to_extract_json_from_markdown_format(text: str) -> str:
|
||||||
|
pattern = r"```json\s*(.*?)\s*```"
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
if match:
|
||||||
|
return match.group(1)
|
||||||
|
else:
|
||||||
|
return text
|
||||||
|
|
Loading…
Add table
Reference in a new issue