Auto-fix invalid JSON (#354)

This commit is contained in:
Kerem Yilmaz 2024-05-21 22:04:32 -07:00 committed by GitHub
parent df09842587
commit e6d4302d8c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,12 +1,16 @@
import base64 import base64
import json
import re import re
from typing import Any from typing import Any
import commentjson import commentjson
import litellm import litellm
import structlog
from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat from skyvern.forge.sdk.api.llm.exceptions import EmptyLLMResponseError, InvalidLLMResponseFormat
LOG = structlog.get_logger()
async def llm_messages_builder( async def llm_messages_builder(
prompt: str, prompt: str,
@ -40,7 +44,8 @@ async def llm_messages_builder(
return [{"role": "user", "content": messages}] return [{"role": "user", "content": messages}]
def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, str]: def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bool = False) -> dict[str, Any]:
content = None
try: try:
content = response.choices[0].message.content content = response.choices[0].message.content
# Since we prefilled Anthropic response with "{" we need to add it back to the response to have a valid json object: # Since we prefilled Anthropic response with "{" we need to add it back to the response to have a valid json object:
@ -52,9 +57,129 @@ def parse_api_response(response: litellm.ModelResponse, add_assistant_prefix: bo
raise EmptyLLMResponseError(str(response)) raise EmptyLLMResponseError(str(response))
return commentjson.loads(content) return commentjson.loads(content)
except Exception as e: except Exception as e:
if content:
LOG.warning("Failed to parse LLM response. Will retry auto-fixing the response for unescaped quotes.")
try:
return fix_and_parse_json_string(content)
except Exception as e2:
LOG.exception("Failed to auto-fix LLM response.", error=str(e2))
raise InvalidLLMResponseFormat(str(response)) from e2
raise InvalidLLMResponseFormat(str(response)) from e raise InvalidLLMResponseFormat(str(response)) from e
def fix_cutoff_json(json_string: str, error_position: int) -> dict[str, Any]:
"""
Fixes a cutoff JSON string by ignoring the last incomplete action and making it a valid JSON.
Args:
json_string (str): The cutoff JSON string to process.
error_position (int): The position of the error in the JSON string.
Returns:
str: The fixed JSON string.
"""
LOG.info("Fixing cutoff JSON string.")
try:
# Truncate the string to the error position
truncated_string = json_string[:error_position]
# Find the last valid action
last_valid_action_pos = truncated_string.rfind("},")
if last_valid_action_pos != -1:
# Remove the incomplete action
fixed_string = truncated_string[: last_valid_action_pos + 1] + "\n ]\n}"
return commentjson.loads(fixed_string)
else:
# If no valid action found, return an empty actions list
LOG.warning("No valid action found in the cutoff JSON string.")
return {"actions": []}
except Exception as e:
raise InvalidLLMResponseFormat(json_string) from e
def fix_unescaped_quotes_in_json(json_string: str) -> str:
    """
    Escape quotation marks that appear unescaped inside JSON string values.

    The string is scanned from left to right. A quote met while inside a string is
    treated as the closing quote only when the next non-whitespace character is a
    JSON structure character, or when nothing but whitespace follows it. Any other
    quote met inside a string is considered part of the text and gets an escape
    character inserted in front of it.

    This is a heuristic: an inner quote that happens to be followed by a structure
    character (for example a comma) is still read as the end of the string.

    Args:
        json_string (str): The JSON-like string to process.

    Returns:
        str: The JSON-like string with the unescaped inner quotation marks escaped.
            The input is returned unchanged when nothing needs escaping.
    """
    escape_char = "\\"
    json_structure_chars = {",", ":", "}", "]", "{", "["}
    length = len(json_string)

    # Positions of the quotes that need an escape character in front of them
    indices_to_add_escape_char = []
    in_string = False
    escape = False

    for i, char in enumerate(json_string):
        if char == escape_char:
            # Toggle so that an escaped escape character does not escape what follows
            escape = not escape
            continue
        if char != '"' or escape:
            # Ordinary character, or a quote that is already escaped
            escape = False
            continue
        if not in_string:
            # Start of the JSON string
            in_string = True
            continue

        # Inside a string: look at the next non-whitespace character to classify the quote
        j = i + 1
        while j < length and json_string[j].isspace():
            j += 1
        if j >= length or json_string[j] in json_structure_chars:
            # End of input or a structure character follows: this quote closes the string.
            # Treating end of input as a terminator keeps a trailing closing quote intact.
            in_string = False
        else:
            # Anything else follows: the quote belongs to the text and must be escaped
            indices_to_add_escape_char.append(i)

    if not indices_to_add_escape_char:
        return json_string

    LOG.warning("Unescaped quotes found in JSON string. Adding escape character to fix the issue.")
    # Indices were collected in ascending order, so the result can be assembled in one pass
    pieces = []
    previous = 0
    for index in indices_to_add_escape_char:
        pieces.append(json_string[previous:index])
        pieces.append(escape_char)
        previous = index
    pieces.append(json_string[previous:])
    return "".join(pieces)
def fix_and_parse_json_string(json_string: str) -> dict[str, Any]:
    """
    Repair a malformed JSON string and parse it.

    Two repairs are attempted in order: unescaped quotes are escaped first, and if
    the result still does not parse, the string is treated as cut off and handed to
    fix_cutoff_json together with the position of the parse error.

    Args:
        json_string (str): The JSON string to process.

    Returns:
        dict[str, Any]: The parsed JSON object.
    """
    LOG.info("Auto-fixing JSON string.")
    # First repair: escape quotes that sit unescaped inside string values
    escaped_string = fix_unescaped_quotes_in_json(json_string)
    try:
        return commentjson.loads(escaped_string)
    except Exception:
        LOG.warning("Failed to parse JSON string. Attempting to fix the JSON string.")
        try:
            # Parsing again with the standard json module looks redundant, but per the
            # original author's note commentjson does not expose the error position
            return json.loads(escaped_string)
        except json.JSONDecodeError as decode_error:
            # Second repair: drop everything from the error position and close the JSON
            return fix_cutoff_json(escaped_string, decode_error.pos)
def replace_useless_text_around_json(input_string: str) -> str: def replace_useless_text_around_json(input_string: str) -> str:
first_occurrence_of_brace = input_string.find("{") first_occurrence_of_brace = input_string.find("{")
last_occurrence_of_brace = input_string.rfind("}") last_occurrence_of_brace = input_string.rfind("}")