Force Claude 3 models to output JSON object and parse it more reliably (#293)

Co-authored-by: otmane <otmanebenazzou.pro@gmail.com>
This commit is contained in:
Kerem Yilmaz 2024-05-10 00:51:12 -07:00 committed by GitHub
parent 49baf471ab
commit cf01e81ba2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 62 additions and 18 deletions

View file

@ -81,7 +81,7 @@ class LLMAPIHandlerFactory:
data=screenshot, data=screenshot,
) )
messages = await llm_messages_builder(prompt, screenshots) messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
if step: if step:
await app.ARTIFACT_MANAGER.create_artifact( await app.ARTIFACT_MANAGER.create_artifact(
step=step, step=step,
@ -115,7 +115,7 @@ class LLMAPIHandlerFactory:
organization_id=step.organization_id, organization_id=step.organization_id,
incremental_cost=llm_cost, incremental_cost=llm_cost,
) )
parsed_response = parse_api_response(response) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
if step: if step:
await app.ARTIFACT_MANAGER.create_artifact( await app.ARTIFACT_MANAGER.create_artifact(
step=step, step=step,
@ -159,7 +159,7 @@ class LLMAPIHandlerFactory:
if not llm_config.supports_vision: if not llm_config.supports_vision:
screenshots = None screenshots = None
messages = await llm_messages_builder(prompt, screenshots) messages = await llm_messages_builder(prompt, screenshots, llm_config.add_assistant_prefix)
if step: if step:
await app.ARTIFACT_MANAGER.create_artifact( await app.ARTIFACT_MANAGER.create_artifact(
step=step, step=step,
@ -199,7 +199,7 @@ class LLMAPIHandlerFactory:
organization_id=step.organization_id, organization_id=step.organization_id,
incremental_cost=llm_cost, incremental_cost=llm_cost,
) )
parsed_response = parse_api_response(response) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix)
if step: if step:
await app.ARTIFACT_MANAGER.create_artifact( await app.ARTIFACT_MANAGER.create_artifact(
step=step, step=step,

View file

@ -55,21 +55,38 @@ if not any(
if SettingsManager.get_settings().ENABLE_OPENAI: if SettingsManager.get_settings().ENABLE_OPENAI:
LLMConfigRegistry.register_config("OPENAI_GPT4_TURBO", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], True)) LLMConfigRegistry.register_config(
LLMConfigRegistry.register_config("OPENAI_GPT4V", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], True)) "OPENAI_GPT4_TURBO",
LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=False, add_assistant_prefix=False),
)
LLMConfigRegistry.register_config(
"OPENAI_GPT4V", LLMConfig("gpt-4-turbo", ["OPENAI_API_KEY"], supports_vision=True, add_assistant_prefix=False)
)
if SettingsManager.get_settings().ENABLE_ANTHROPIC: if SettingsManager.get_settings().ENABLE_ANTHROPIC:
LLMConfigRegistry.register_config( LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True) "ANTHROPIC_CLAUDE3",
LLMConfig(
"anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
),
) )
LLMConfigRegistry.register_config( LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_OPUS", LLMConfig("anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], True) "ANTHROPIC_CLAUDE3_OPUS",
LLMConfig(
"anthropic/claude-3-opus-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
),
) )
LLMConfigRegistry.register_config( LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_SONNET", LLMConfig("anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], True) "ANTHROPIC_CLAUDE3_SONNET",
LLMConfig(
"anthropic/claude-3-sonnet-20240229", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
),
) )
LLMConfigRegistry.register_config( LLMConfigRegistry.register_config(
"ANTHROPIC_CLAUDE3_HAIKU", LLMConfig("anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], True) "ANTHROPIC_CLAUDE3_HAIKU",
LLMConfig(
"anthropic/claude-3-haiku-20240307", ["ANTHROPIC_API_KEY"], supports_vision=True, add_assistant_prefix=True
),
) )
if SettingsManager.get_settings().ENABLE_BEDROCK: if SettingsManager.get_settings().ENABLE_BEDROCK:
@ -79,7 +96,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
LLMConfig( LLMConfig(
"bedrock/anthropic.claude-3-opus-20240229-v1:0", "bedrock/anthropic.claude-3-opus-20240229-v1:0",
["AWS_REGION"], ["AWS_REGION"],
True, supports_vision=True,
add_assistant_prefix=True,
), ),
) )
LLMConfigRegistry.register_config( LLMConfigRegistry.register_config(
@ -87,7 +105,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
LLMConfig( LLMConfig(
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
["AWS_REGION"], ["AWS_REGION"],
True, supports_vision=True,
add_assistant_prefix=True,
), ),
) )
LLMConfigRegistry.register_config( LLMConfigRegistry.register_config(
@ -95,7 +114,8 @@ if SettingsManager.get_settings().ENABLE_BEDROCK:
LLMConfig( LLMConfig(
"bedrock/anthropic.claude-3-haiku-20240307-v1:0", "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
["AWS_REGION"], ["AWS_REGION"],
True, supports_vision=True,
add_assistant_prefix=True,
), ),
) )
@ -105,6 +125,7 @@ if SettingsManager.get_settings().ENABLE_AZURE:
LLMConfig( LLMConfig(
f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}", f"azure/{SettingsManager.get_settings().AZURE_DEPLOYMENT}",
["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"], ["AZURE_DEPLOYMENT", "AZURE_API_KEY", "AZURE_API_BASE", "AZURE_API_VERSION"],
True, supports_vision=True,
add_assistant_prefix=False,
), ),
) )

View file

@ -10,6 +10,7 @@ class LLMConfig:
model_name: str model_name: str
required_env_vars: list[str] required_env_vars: list[str]
supports_vision: bool supports_vision: bool
add_assistant_prefix: bool
def get_missing_env_vars(self) -> list[str]: def get_missing_env_vars(self) -> list[str]:
missing_env_vars = [] missing_env_vars = []

View file

@ -1,4 +1,5 @@
import base64 import base64
import re
from typing import Any from typing import Any
import commentjson import commentjson
@ -10,6 +11,7 @@ from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidL
async def llm_messages_builder( async def llm_messages_builder(
prompt: str, prompt: str,
screenshots: list[bytes] | None = None, screenshots: list[bytes] | None = None,
add_assistant_prefix: bool = False,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
messages: list[dict[str, Any]] = [ messages: list[dict[str, Any]] = [
{ {
@ -29,17 +31,37 @@ async def llm_messages_builder(
}, },
} }
) )
# Anthropic models seem to struggle to consistently output a valid JSON object, so we prefill the response to force it:
if add_assistant_prefix:
return [{"role": "user", "content": messages}, {"role": "assistant", "content": "{"}]
return [{"role": "user", "content": messages}] return [{"role": "user", "content": messages}]
def parse_api_response(response: litellm.ModelResponse) -> dict[str, str]: def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, str]:
try: try:
content = response.choices[0].message.content content = response.choices[0].message.content
content = content.replace("```json", "") # Since we prefilled Anthropic response with "{" we need to add it back to the response to have a valid json object:
content = content.replace("```", "") if add_assistant_prefix:
content = "{" + content
content = try_to_extract_json_from_markdown_format(content)
content = replace_useless_text_around_json(content)
if not content: if not content:
raise EmptyLLMResponseError(str(response)) raise EmptyLLMResponseError(str(response))
return commentjson.loads(content) return commentjson.loads(content)
except Exception as e: except Exception as e:
raise InvalidLLMResponseFormat(str(response)) from e raise InvalidLLMResponseFormat(str(response)) from e
def replace_useless_text_around_json(input_string: str) -> str:
    """Strip any text surrounding the outermost JSON object in *input_string*.

    LLMs often wrap their JSON answer in explanatory prose; this keeps only the
    span from the first ``{`` to the last ``}`` (inclusive).

    Returns an empty string when no well-ordered brace pair exists, so the
    caller's empty-content check (EmptyLLMResponseError) fires instead of
    feeding garbage to the JSON parser.
    """
    first_occurrence_of_brace = input_string.find("{")
    last_occurrence_of_brace = input_string.rfind("}")
    # Guard the no-JSON case: find/rfind return -1 when the brace is absent,
    # and negative slice indices would otherwise yield accidental substrings
    # (e.g. "}" sliced as [-1:1] returns "}").
    if first_occurrence_of_brace == -1 or last_occurrence_of_brace < first_occurrence_of_brace:
        return ""
    return input_string[first_occurrence_of_brace : last_occurrence_of_brace + 1]
def try_to_extract_json_from_markdown_format(text: str) -> str:
    """Return the payload of a ```json fenced block in *text*, if one exists.

    Falls back to returning *text* unchanged when no fenced block is found.
    """
    fenced_block = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
    return fenced_block.group(1) if fenced_block else text