Skyvern/streamlit_app/visualizer/streamlit.py
2024-03-01 10:09:30 -08:00

383 lines
16 KiB
Python

import pandas as pd
import streamlit as st
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, TaskRequest
from streamlit_app.visualizer import styles
from streamlit_app.visualizer.api import SkyvernClient
from streamlit_app.visualizer.artifact_loader import (
read_artifact_safe,
streamlit_content_safe,
streamlit_show_recording,
)
from streamlit_app.visualizer.repository import TaskRepository
from streamlit_app.visualizer.sample_data import (
get_sample_data_extraction_goal,
get_sample_extracted_information_schema,
get_sample_navigation_goal,
get_sample_navigation_payload,
get_sample_url,
)
# Streamlit UI configuration: render the app with the wide page layout.
st.set_page_config(layout="wide")
# Apply the style snippets from the `styles` module (page font first, then
# buttons). They are raw HTML, hence unsafe_allow_html.
for style_snippet in (styles.page_font_style, styles.button_style):
    st.markdown(style_snippet, unsafe_allow_html=True)
# Configuration
def reset_session_state() -> None:
    """Delete every item in Streamlit's session state.

    Passed as the ``on_change`` callback of the environment and organization
    select boxes, so everything stored for the previous selection is dropped
    when either of them changes.
    """
    # Snapshot the keys before deleting: entries must not be removed from the
    # mapping while its live key view is still being iterated.
    for key in list(st.session_state.keys()):
        del st.session_state[key]
# Per-environment configuration entries, read from the Streamlit secrets.
CONFIGS_DICT = st.secrets["skyvern"]["configs"]
if not CONFIGS_DICT:
    raise Exception("No configuration found. Copy the values from 1P and restart the app.")
# Map each environment name to its host and an {organization name: credential}
# lookup. Every entry is expected to provide "env", "host" and "orgs", and
# every org a "name" and a "cred".
SETTINGS = {
    config["env"]: {
        "host": config["host"],
        "orgs": {org["name"]: org["cred"] for org in config["orgs"]},
    }
    for config in CONFIGS_DICT
}
st.sidebar.markdown("#### **Settings**")
# Environment selector; changing it wipes the session state via the callback.
environment_names = list(SETTINGS)
select_env = st.sidebar.selectbox("Environment", environment_names, on_change=reset_session_state)
# Organization selector, limited to the organizations of the chosen environment.
organization_names = list(SETTINGS[select_env]["orgs"])
select_org = st.sidebar.selectbox("Organization", organization_names, on_change=reset_session_state)
# Initialize session state: each entry is created only when it is missing, so
# existing values survive script reruns.
session = st.session_state
if "client" not in session:
    env_settings = SETTINGS[select_env]
    session.client = SkyvernClient(
        base_url=env_settings["host"],
        credentials=env_settings["orgs"][select_org],
    )
if "repository" not in session:
    session.repository = TaskRepository(session.client)
if "task_page_number" not in session:
    session.task_page_number = 1
if "selected_task" not in session:
    # The task-related entries are created together, keyed off "selected_task".
    for task_key in ("selected_task", "selected_task_recording_uri", "task_steps"):
        session[task_key] = None
if "selected_step" not in session:
    # Likewise for the step-related entries, keyed off "selected_step".
    for step_key in ("selected_step", "selected_step_index"):
        session[step_key] = None

# Module-level snapshots of the session state, read by the handlers and the
# UI code below during this script run.
client = session.client
repository = session.repository
task_page_number = session.task_page_number
selected_task = session.selected_task
selected_task_recording_uri = session.selected_task_recording_uri
task_steps = session.task_steps
selected_step = session.selected_step
selected_step_index = session.selected_step_index
# Onclick handlers
def select_task(task: dict) -> None:
    """Make ``task`` the selected task and load its recording URI and steps.

    Stores the task, its recording URI and its steps in Streamlit's session
    state, then selects the first step when the task has any steps.

    Args:
        task: Task record; must contain a ``"task_id"`` key.
    """
    st.session_state.selected_task = task
    st.session_state.selected_task_recording_uri = repository.get_task_recording_uri(task)
    # Reset the step selection, index included, so that a task without steps
    # does not keep the index of a step from the previously selected task.
    st.session_state.selected_step = None
    st.session_state.selected_step_index = None
    # Save the task's steps in session state.
    st.session_state.task_steps = repository.get_task_steps(task["task_id"])
    if st.session_state.task_steps:
        st.session_state.selected_step = st.session_state.task_steps[0]
        st.session_state.selected_step_index = 0
def go_to_previous_step() -> None:
    """Select the step before the current one, staying on the first step."""
    previous_index = selected_step_index - 1
    if previous_index < 0:
        # Already on the first step: clamp instead of wrapping around.
        previous_index = 0
    select_step(task_steps[previous_index])
def go_to_next_step() -> None:
    """Select the step after the current one, staying on the last step."""
    last_index = len(task_steps) - 1
    next_index = selected_step_index + 1
    if next_index > last_index:
        # Already on the last step: clamp to it.
        next_index = last_index
    select_step(task_steps[next_index])
def select_step(step: dict) -> None:
    """Record ``step`` and its position in ``task_steps`` as the selected step.

    Args:
        step: A step record that is present in the module-level ``task_steps``.

    Raises:
        ValueError: If ``step`` is not found in ``task_steps``.
    """
    session = st.session_state
    session["selected_step"] = step
    session["selected_step_index"] = task_steps.index(step)
# Streamlit UI Logic
# Page header: app title, then the active environment and organization.
st.markdown("# **:dragon: Skyvern :dragon:**")
st.markdown(f"### **{select_env} - {select_org}**")
# The two top-level tabs of the app.
tab_labels = ["Execute", "Visualizer"]
execute_tab, visualizer_tab = st.tabs(tab_labels)
with execute_tab:
    create_column, explanation_column = st.columns([1, 2])

    # Left column: a form that collects the TaskRequest fields and submits them.
    with create_column, st.form("task_form"):
        st.markdown("## Run a task")
        # One input widget per TaskRequest field; all but the webhook URL are
        # pre-filled from the sample data helpers.
        url_input = st.text_input("URL*", key="url", value=get_sample_url())
        webhook_input = st.text_input("Webhook Callback URL", key="webhook", placeholder="Optional")
        navigation_goal_input = st.text_input(
            "Navigation Goal",
            key="nav_goal",
            value=get_sample_navigation_goal(),
            placeholder="Describe the navigation goal",
        )
        data_extraction_goal_input = st.text_input(
            "Data Extraction Goal",
            key="data_goal",
            value=get_sample_data_extraction_goal(),
            placeholder="Describe the data extraction goal",
        )
        navigation_payload_input = st.text_area(
            "Navigation Payload JSON",
            key="nav_payload",
            value=get_sample_navigation_payload(),
            placeholder='{"name": "John Doe", "email": "abc@123.com"}',
        )
        extracted_schema_input = st.text_area(
            "Extracted Information Schema",
            key="extracted_info_schema",
            value=get_sample_extracted_information_schema(),
            placeholder='{"quote_price": "float"}',
        )

        # Assemble the request from the current widget values. The two text
        # areas are passed through as the raw text typed into them.
        task_request_body = TaskRequest(
            url=url_input,
            webhook_callback_url=webhook_input,
            navigation_goal=navigation_goal_input,
            data_extraction_goal=data_extraction_goal_input,
            proxy_location=ProxyLocation.NONE,
            navigation_payload=navigation_payload_input,
            extracted_information_schema=extracted_schema_input,
        )

        # Submit the form: create the task through the API and report the outcome.
        submitted = st.form_submit_button("Execute Task", use_container_width=True)
        if submitted:
            task_id = client.create_task(task_request_body)
            if task_id:
                st.success("Task created successfully, task_id: " + task_id)
            else:
                st.error("Failed to create task!")

    # Right column: a short description of every request field.
    with explanation_column:
        st.markdown("### **Task Request**")
        field_descriptions = (
            ("#### **URL**", "The starting URL for the task."),
            ("#### **Webhook Callback URL**", "The URL to call with the results when the task is completed."),
            (
                "#### **Navigation Goal**",
                "The user's goal for the task. Nullable if the task is only for data extraction.",
            ),
            (
                "#### **Data Extraction Goal**",
                "The user's goal for data extraction. Nullable if the task is only for navigation.",
            ),
            (
                "#### **Navigation Payload**",
                "The user's details needed to achieve the task. AI will use this information as needed.",
            ),
            (
                "#### **Extracted Information Schema**",
                "The requested schema of the extracted information for data extraction goal.",
            ),
        )
        for heading, description in field_descriptions:
            st.markdown(heading)
            st.markdown(description)
with visualizer_tab:
    task_id_input = st.text_input("task_id", value="")

    def search_task() -> None:
        """Look up the task whose id was typed in and select it, or show an error if none is found."""
        if not task_id_input:
            return
        task = repository.get_task(task_id_input)
        if task:
            select_task(task)
        else:
            st.error(f"Task with id {task_id_input} not found.")

    def go_to_next_task_page() -> None:
        """Advance the task list to the next page."""
        st.session_state.task_page_number += 1

    def go_to_previous_task_page() -> None:
        """Move the task list back one page, never below page 1."""
        st.session_state.task_page_number = max(1, st.session_state.task_page_number - 1)

    st.button("search task", on_click=search_task)

    col_tasks, _, col_steps, _, col_artifacts = st.columns([4, 1, 6, 1, 18])
    col_tasks.markdown("#### Tasks")
    col_steps.markdown("#### Steps")
    col_artifacts.markdown("#### Artifacts")

    tasks_response = repository.get_tasks(task_page_number)
    if "error" in tasks_response:
        # Show the error payload and render an empty task list; the payload
        # is not a collection of task records and must not be iterated as one.
        st.write(tasks_response)
        tasks = {}
    else:
        tasks = {task["task_id"]: task for task in tasks_response}

    # One button per task; the currently selected task is highlighted.
    for task_id, task in tasks.items():
        col_tasks.button(
            f"{task_id}",
            on_click=select_task,
            args=(task,),
            use_container_width=True,
            type="primary" if selected_task and task_id == selected_task["task_id"] else "secondary",
        )

    # Pagination buttons. The page number is changed in on_click callbacks,
    # which Streamlit runs before the script reruns, so the rerun triggered by
    # the click already fetches and shows the new page.
    task_page_prev, _, show_task_page_number, _, task_page_next = col_tasks.columns([1, 1, 1, 1, 1])
    show_task_page_number.button(str(task_page_number), disabled=True)
    # The labels are a literal backslash followed by the angle bracket.
    task_page_next.button("\\>", on_click=go_to_next_task_page)
    task_page_prev.button("\\<", on_click=go_to_previous_task_page, disabled=task_page_number == 1)

    (
        tab_task,
        tab_step,
        tab_recording,
        tab_screenshot,
        tab_post_action_screenshot,
        tab_id_to_xpath,
        tab_element_tree,
        tab_element_tree_trimmed,
        tab_llm_prompt,
        tab_llm_request,
        tab_llm_response_parsed,
        tab_llm_response_raw,
        tab_html,
    ) = col_artifacts.tabs(
        [
            ":green[Task]",
            ":blue[Step]",
            ":violet[Recording]",
            ":rainbow[Screenshot]",
            ":rainbow[Action Screenshots]",
            ":red[ID -> XPath]",
            ":orange[Element Tree]",
            ":blue[Element Tree (Trimmed)]",
            ":yellow[LLM Prompt]",
            ":green[LLM Request]",
            ":blue[LLM Response (Parsed)]",
            ":violet[LLM Response (Raw)]",
            ":rainbow[Html (Raw)]",
        ]
    )

    tab_task_details, tab_task_steps, tab_task_action_results = tab_task.tabs(["Details", "Steps", "Action Results"])

    if selected_task:
        tab_task_details.json(selected_task)
        if selected_task_recording_uri:
            streamlit_show_recording(tab_recording, selected_task_recording_uri)

        if task_steps:
            col_steps_prev, _, col_steps_next = col_steps.columns([3, 1, 3])
            col_steps_prev.button(
                "prev", on_click=go_to_previous_step, key="previous_step_button", use_container_width=True
            )
            col_steps_next.button("next", on_click=go_to_next_step, key="next_step_button", use_container_width=True)

            # One button per step, labelled "<order> - <retry_index> - <step_id>";
            # the currently selected step is highlighted.
            for step in task_steps:
                col_steps.button(
                    f"{step['order']} - {step['retry_index']} - {step['step_id']}",
                    on_click=select_step,
                    args=(step,),
                    use_container_width=True,
                    type="primary" if selected_step and step["step_id"] == selected_step["step_id"] else "secondary",
                )

            df = pd.json_normalize(task_steps)
            tab_task_steps.dataframe(df, use_container_width=True, height=1000)

            # Flatten the action results of all steps into one table; each row
            # carries the identifying fields of the step it came from.
            task_action_results = []
            for step in task_steps:
                output = step.get("output")
                if not output:
                    continue
                for action_result in output.get("action_results", []):
                    task_action_results.append(
                        {
                            "step_id": step["step_id"],
                            "order": step["order"],
                            "retry_index": step["retry_index"],
                            **action_result,
                        }
                    )
            df = pd.json_normalize(task_action_results)
            # Sort the columns alphabetically.
            df = df.reindex(sorted(df.columns), axis=1)
            tab_task_action_results.dataframe(df, use_container_width=True, height=1000)

        if selected_step:
            tab_step.json(selected_step)

            artifacts_response = repository.get_artifacts(selected_task["task_id"], selected_step["step_id"])
            # Key the artifacts by file name (last path segment of the URI);
            # a later artifact with the same file name replaces an earlier one.
            file_name_to_uris = {artifact["uri"].split("/")[-1]: artifact["uri"] for artifact in artifacts_response}

            # Route each artifact to its tab based on the file-name suffix.
            for file_name, uri in file_name_to_uris.items():
                file_name = file_name.lower()
                if file_name.endswith(("screenshot_llm.png", "screenshot.png")):
                    streamlit_content_safe(
                        tab_screenshot,
                        tab_screenshot.image,
                        read_artifact_safe(uri, is_image=True),
                        "No screenshot available.",
                        use_column_width=True,
                    )
                elif file_name.endswith("screenshot_action.png"):
                    streamlit_content_safe(
                        tab_post_action_screenshot,
                        tab_post_action_screenshot.image,
                        read_artifact_safe(uri, is_image=True),
                        "No action screenshot available.",
                        use_column_width=True,
                    )
                elif file_name.endswith("id_xpath_map.json"):
                    streamlit_content_safe(
                        tab_id_to_xpath, tab_id_to_xpath.json, read_artifact_safe(uri), "No ID -> XPath map available."
                    )
                elif file_name.endswith("tree.json"):
                    streamlit_content_safe(
                        tab_element_tree,
                        tab_element_tree.json,
                        read_artifact_safe(uri),
                        "No element tree available.",
                    )
                elif file_name.endswith("tree_trimmed.json"):
                    streamlit_content_safe(
                        tab_element_tree_trimmed,
                        tab_element_tree_trimmed.json,
                        read_artifact_safe(uri),
                        "No element tree trimmed available.",
                    )
                elif file_name.endswith("llm_prompt.txt"):
                    content = read_artifact_safe(uri)
                    # HACK: this is a hacky way to call this generic method to get
                    # it working with st.text_area; the content is also passed
                    # as `value`.
                    streamlit_content_safe(
                        tab_llm_prompt,
                        tab_llm_prompt.text_area,
                        content,
                        "No LLM prompt available.",
                        value=content,
                        height=1000,
                        label_visibility="collapsed",
                    )
                elif file_name.endswith("llm_request.json"):
                    streamlit_content_safe(
                        tab_llm_request, tab_llm_request.json, read_artifact_safe(uri), "No LLM request available."
                    )
                elif file_name.endswith("llm_response_parsed.json"):
                    streamlit_content_safe(
                        tab_llm_response_parsed,
                        tab_llm_response_parsed.json,
                        read_artifact_safe(uri),
                        "No parsed LLM response available.",
                    )
                elif file_name.endswith("llm_response.json"):
                    streamlit_content_safe(
                        tab_llm_response_raw,
                        tab_llm_response_raw.json,
                        read_artifact_safe(uri),
                        "No raw LLM response available.",
                    )
                elif file_name.endswith(("html_scrape.html", "html_action.html")):
                    # Both HTML artifact kinds are shown in the same tab.
                    streamlit_content_safe(tab_html, tab_html.text, read_artifact_safe(uri), "No html available.")
                else:
                    st.write(f"Artifact {file_name} not supported.")