Skyvern/skyvern/forge/sdk/routes/workflow_copilot.py

1224 lines
51 KiB
Python

import asyncio
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import structlog
import yaml
from fastapi import Depends, HTTPException, Request, status
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from skyvern.config import settings
from skyvern.constants import DEFAULT_LOGIN_PROMPT
from skyvern.forge import app
from skyvern.forge.prompts import prompt_engine
from skyvern.forge.sdk.api.llm.api_handler import LLMAPIHandler
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
from skyvern.forge.sdk.copilot.agent import run_copilot_agent
from skyvern.forge.sdk.copilot.output_utils import truncate_output
from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
from skyvern.forge.sdk.routes.event_source_stream import EventSourceStream, FastAPIEventSourceStream
from skyvern.forge.sdk.routes.routers import base_router
from skyvern.forge.sdk.schemas.organizations import Organization
from skyvern.forge.sdk.schemas.workflow_copilot import (
WorkflowCopilotChatHistoryMessage,
WorkflowCopilotChatHistoryResponse,
WorkflowCopilotChatMessage,
WorkflowCopilotChatRequest,
WorkflowCopilotChatSender,
WorkflowCopilotClearProposedWorkflowRequest,
WorkflowCopilotProcessingUpdate,
WorkflowCopilotStreamErrorUpdate,
WorkflowCopilotStreamMessageType,
WorkflowCopilotStreamResponseUpdate,
WorkflowYAMLConversionRequest,
WorkflowYAMLConversionResponse,
)
from skyvern.forge.sdk.services import org_auth_service
from skyvern.forge.sdk.workflow.exceptions import BaseWorkflowHTTPException
from skyvern.forge.sdk.workflow.models.parameter import ParameterType
from skyvern.forge.sdk.workflow.models.workflow import Workflow
from skyvern.forge.sdk.workflow.workflow_definition_converter import convert_workflow_definition
from skyvern.schemas.workflows import (
BlockYAML,
BranchConditionYAML,
ConditionalBlockYAML,
ForLoopBlockYAML,
LoginBlockYAML,
WorkflowCreateYAMLRequest,
WorkflowDefinitionYAML,
)
from skyvern.utils.strings import escape_code_fences
from skyvern.utils.yaml_loader import safe_load_no_dates
# Text file whose contents are passed to the "workflow-copilot-system" prompt
# template as ``workflow_knowledge_base``. The path is relative, so it is
# resolved against the process's current working directory when read.
WORKFLOW_KNOWLEDGE_BASE_PATH = Path("skyvern/forge/prompts/skyvern/workflow_knowledge_base.txt")
# How many of the most recent chat messages are handed to the copilot agent
# as conversation history.
CHAT_HISTORY_CONTEXT_MESSAGES = 10
# Module-level structlog logger used by every function in this file.
LOG = structlog.get_logger()
@dataclass(frozen=True)
class RunInfo:
    """Immutable summary of one workflow-run block plus optional page HTML.

    Instances are built by ``_get_debug_run_info`` from the first block
    returned for a workflow run.
    """

    # ``label`` of the run block.
    block_label: str | None
    # ``block_type.name`` of the run block.
    block_type: str
    # ``status`` of the run block.
    block_status: str | None
    # ``failure_reason`` recorded on the run block, if any.
    failure_reason: str | None
    # UTF-8 decoded VISIBLE_ELEMENTS_TREE artifact of the run, or None when no
    # artifact (or no artifact content) was found.
    html: str | None
# New-copilot richer block shape (used only from the ENABLE_WORKFLOW_COPILOT_V2
# dispatch path). Kept side-by-side with the old RunInfo so the old-copilot
# body stays untouched; consolidation is SKY-8916's job.
@dataclass(frozen=True)
class BlockRunInfo:
    """Immutable per-block record produced by ``_get_new_copilot_block_infos``.

    Unlike ``RunInfo`` it carries the block's (truncated) output instead of
    page HTML; the HTML is returned separately by the builder function.
    """

    # ``label`` of the run block.
    block_label: str | None
    # ``block_type.name`` when the block type exposes ``name``; otherwise its ``str()``.
    block_type: str
    # ``status`` of the run block.
    block_status: str | None
    # ``failure_reason`` recorded on the run block, if any.
    failure_reason: str | None
    # Result of ``truncate_output`` applied to the block's ``output`` attribute
    # (None is passed to it when the attribute is missing).
    output: str | None
def _should_restore_persisted_workflow(auto_accept: bool | None, agent_result: object | None) -> bool:
"""Return True when a persisted draft should be rolled back.
SKY-9143: when the agent decided not to ship a proposal this turn
(``updated_workflow is None``) but ``_update_workflow`` already committed
a YAML to ``workflow_definition``, we must restore the original even under
``auto_accept=True`` — otherwise an unverified edit becomes the live
workflow silently.
"""
if not bool(getattr(agent_result, "workflow_was_persisted", False)):
return False
if getattr(agent_result, "updated_workflow", None) is None:
return True
return auto_accept is not True
async def _restore_workflow_definition(original_workflow: Workflow | None, organization_id: str) -> None:
"""Roll the workflow back to ``original_workflow``.
Unconditional restore helper. Callers must first gate this with
``_should_restore_persisted_workflow`` so success, disconnect, and exception
paths all apply the same rollback rule: only restore when the user did not
opt into auto-accept AND the agent loop actually persisted a mid-request
draft.
"""
if not original_workflow:
return
try:
await app.WORKFLOW_SERVICE.update_workflow_definition(
workflow_id=original_workflow.workflow_id,
organization_id=organization_id,
title=original_workflow.title,
description=original_workflow.description,
workflow_definition=original_workflow.workflow_definition,
)
except Exception:
LOG.warning(
"Failed to restore original workflow",
workflow_id=original_workflow.workflow_id,
exc_info=True,
)
async def _get_debug_artifact(organization_id: str, workflow_run_id: str) -> Artifact | None:
    """Return the first VISIBLE_ELEMENTS_TREE artifact of a run, if any.

    ``None`` is returned when the lookup yields an empty result or anything
    that is not a list.
    """
    found = await app.DATABASE.artifacts.get_artifacts_for_run(
        run_id=workflow_run_id,
        organization_id=organization_id,
        artifact_types=[ArtifactType.VISIBLE_ELEMENTS_TREE],
    )
    if not isinstance(found, list) or not found:
        return None
    return found[0]
async def _get_debug_run_info(organization_id: str, workflow_run_id: str | None) -> RunInfo | None:
    """Build a ``RunInfo`` from the first block of a workflow run.

    Returns ``None`` when no run id is given or the run has no blocks. The
    ``html`` field is the UTF-8 decoded VISIBLE_ELEMENTS_TREE artifact of the
    run, or ``None`` when no artifact or no artifact content is available.
    """
    if not workflow_run_id:
        return None

    run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
        workflow_run_id=workflow_run_id, organization_id=organization_id
    )
    if not run_blocks:
        return None
    first_block = run_blocks[0]

    page_html: str | None = None
    tree_artifact = await _get_debug_artifact(organization_id, workflow_run_id)
    if tree_artifact:
        raw_bytes = await app.ARTIFACT_MANAGER.retrieve_artifact(tree_artifact)
        if raw_bytes:
            page_html = raw_bytes.decode("utf-8")

    return RunInfo(
        block_label=first_block.label,
        block_type=first_block.block_type.name,
        block_status=first_block.status,
        failure_reason=first_block.failure_reason,
        html=page_html,
    )
async def _get_new_copilot_block_infos(
    organization_id: str, workflow_run_id: str | None
) -> tuple[list[BlockRunInfo], str | None]:
    """Collect per-block run records and the run's HTML for the copilot-v2 path.

    Counterpart of ``_get_debug_run_info`` used by the
    ENABLE_WORKFLOW_COPILOT_V2 dispatch: instead of a single block it returns
    one ``BlockRunInfo`` per run block, together with the decoded
    VISIBLE_ELEMENTS_TREE artifact (or ``None``).

    Returns ``([], None)`` when no run id is given or the run has no blocks.
    """

    def _summarize(run_block: Any) -> BlockRunInfo:
        # Use the type's ``name`` when it has one, otherwise its string form.
        kind = run_block.block_type
        kind_name = kind.name if hasattr(kind, "name") else str(kind)
        return BlockRunInfo(
            block_label=run_block.label,
            block_type=kind_name,
            block_status=run_block.status,
            failure_reason=run_block.failure_reason,
            output=truncate_output(getattr(run_block, "output", None)),
        )

    if not workflow_run_id:
        return [], None

    run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
        workflow_run_id=workflow_run_id, organization_id=organization_id
    )
    if not run_blocks:
        return [], None

    summaries = [_summarize(run_block) for run_block in run_blocks]

    tree_artifact = await _get_debug_artifact(organization_id, workflow_run_id)
    if not tree_artifact:
        return summaries, None
    raw_bytes = await app.ARTIFACT_MANAGER.retrieve_artifact(tree_artifact)
    return summaries, (raw_bytes.decode("utf-8") if raw_bytes else None)
def _format_chat_history(chat_history: list[WorkflowCopilotChatHistoryMessage]) -> str:
chat_history_text = ""
if chat_history:
history_lines = [f"{msg.sender}: {msg.content}" for msg in chat_history]
chat_history_text = "\n".join(history_lines)
return chat_history_text
def _parse_llm_response(llm_response: dict[str, Any] | Any) -> Any:
if isinstance(llm_response, dict) and "output" in llm_response:
action_data = llm_response["output"]
else:
action_data = llm_response
if not isinstance(action_data, dict):
LOG.error(
"LLM response is not valid JSON",
response_type=type(action_data).__name__,
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Invalid response from LLM",
)
return action_data
async def copilot_call_llm(
    stream: EventSourceStream,
    organization_id: str,
    chat_request: WorkflowCopilotChatRequest,
    chat_history: list[WorkflowCopilotChatHistoryMessage],
    global_llm_context: str | None,
    debug_run_info_text: str,
) -> tuple[str, Workflow | None, str | None]:
    """Run one copilot LLM turn and translate the reply into a chat result.

    Renders the system and user prompts, calls the LLM handler selected for
    the "workflow-copilot" prompt type (falling back to the app default), and
    dispatches on the ``type`` field of the parsed reply:

    * ``REPLACE_WORKFLOW`` - the returned YAML is turned into a ``Workflow``.
      If that fails with a YAML, validation or workflow HTTP error, a status
      update is streamed and the LLM is asked once to fix the YAML; a second
      failure propagates to the caller.
    * ``REPLY`` / ``ASK_QUESTION`` - text only, no workflow.
    * anything else - logged as an error; a generic fallback message is
      returned with no workflow and no context.

    Returns:
        ``(user_response, updated_workflow_or_None, global_llm_context_or_None)``.
        The returned context is the one supplied by the LLM reply, not the
        one passed in.
    """
    fallback_reply = "I received your request but I'm not sure how to help. Could you rephrase?"

    history_text = _format_chat_history(chat_history)
    knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
    current_yaml = chat_request.workflow_yaml or ""
    incoming_context = global_llm_context or ""

    # System prompt carries trusted content only; the security rules are
    # supplied by AgentFunction.
    copilot_rules = app.AGENT_FUNCTION.get_copilot_security_rules()
    system_prompt = prompt_engine.load_prompt(
        template="workflow-copilot-system",
        workflow_knowledge_base=knowledge_base,
        current_datetime=datetime.now(timezone.utc).isoformat(),
        security_rules=copilot_rules,
    )

    # User prompt carries untrusted content; every variable is escaped so
    # triple backticks cannot break out of the template's code fences.
    user_prompt = prompt_engine.load_prompt(
        template="workflow-copilot-user",
        workflow_yaml=escape_code_fences(current_yaml),
        user_message=escape_code_fences(chat_request.message),
        chat_history=escape_code_fences(history_text),
        global_llm_context=escape_code_fences(incoming_context),
        debug_run_info=escape_code_fences(debug_run_info_text),
    )

    LOG.info(
        "Calling LLM",
        workflow_permanent_id=chat_request.workflow_permanent_id,
        workflow_id=chat_request.workflow_id,
        user_message_len=len(chat_request.message),
        user_message=chat_request.message,
        workflow_yaml_len=len(current_yaml),
        workflow_yaml=current_yaml,
        chat_history_len=len(history_text),
        chat_history=history_text,
        global_llm_context_len=len(incoming_context),
        global_llm_context=incoming_context,
        workflow_knowledge_base_len=len(knowledge_base),
        debug_run_info_len=len(debug_run_info_text),
        system_prompt_len=len(system_prompt),
        user_prompt_len=len(user_prompt),
    )

    experiment_handler = await get_llm_handler_for_prompt_type(
        "workflow-copilot", chat_request.workflow_permanent_id, organization_id
    )
    llm_api_handler = experiment_handler or app.LLM_API_HANDLER

    call_started = time.monotonic()
    llm_response = await llm_api_handler(
        prompt=user_prompt,
        prompt_name="workflow-copilot",
        organization_id=organization_id,
        system_prompt=system_prompt,
    )
    LOG.info(
        "LLM response",
        workflow_permanent_id=chat_request.workflow_permanent_id,
        workflow_id=chat_request.workflow_id,
        duration_seconds=time.monotonic() - call_started,
        user_message_len=len(chat_request.message),
        workflow_yaml_len=len(current_yaml),
        chat_history_len=len(history_text),
        global_llm_context_len=len(incoming_context),
        debug_run_info_len=len(debug_run_info_text),
        workflow_knowledge_base_len=len(knowledge_base),
        llm_response_len=len(llm_response),
        llm_response=llm_response,
    )

    action_data = _parse_llm_response(llm_response)
    action_type = action_data.get("type")

    raw_reply = action_data.get("user_response")
    user_response = fallback_reply if raw_reply is None else str(raw_reply)

    LOG.info(
        "LLM response received",
        workflow_permanent_id=chat_request.workflow_permanent_id,
        workflow_id=chat_request.workflow_id,
        organization_id=organization_id,
        action_type=action_type,
    )

    # From here on the context is whatever the LLM sent back (stringified).
    global_llm_context = action_data.get("global_llm_context")
    if global_llm_context is not None:
        global_llm_context = str(global_llm_context)

    if action_type in ("REPLY", "ASK_QUESTION"):
        return user_response, None, global_llm_context

    if action_type != "REPLACE_WORKFLOW":
        LOG.error(
            "Unknown action type from LLM",
            organization_id=organization_id,
            action_type=action_type,
        )
        return fallback_reply, None, None

    def _materialize(yaml_text: str) -> Workflow:
        # Shared call shape for the first attempt and the post-correction retry.
        return _process_workflow_yaml(
            workflow_id=chat_request.workflow_id,
            workflow_permanent_id=chat_request.workflow_permanent_id,
            organization_id=organization_id,
            workflow_yaml=yaml_text,
        )

    proposed_yaml = action_data.get("workflow_yaml", "")
    try:
        updated_workflow = _materialize(proposed_yaml)
    except (yaml.YAMLError, ValidationError, BaseWorkflowHTTPException) as e:
        await stream.send(
            WorkflowCopilotProcessingUpdate(
                type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                status="Validating workflow definition...",
                timestamp=datetime.now(timezone.utc),
            )
        )
        repaired_yaml = await _auto_correct_workflow_yaml(
            llm_api_handler=llm_api_handler,
            organization_id=organization_id,
            user_response=user_response,
            workflow_yaml=proposed_yaml,
            chat_history=chat_history,
            global_llm_context=global_llm_context,
            debug_run_info_text=debug_run_info_text,
            error=e,
        )
        # Single retry: a failure here is not caught and reaches the caller.
        updated_workflow = _materialize(repaired_yaml)

    return user_response, updated_workflow, global_llm_context
async def _auto_correct_workflow_yaml(
    llm_api_handler: LLMAPIHandler,
    organization_id: str,
    user_response: str,
    workflow_yaml: str,
    chat_history: list[WorkflowCopilotChatHistoryMessage],
    global_llm_context: str | None,
    debug_run_info_text: str,
    error: Exception,
) -> str:
    """Ask the LLM to repair workflow YAML that failed to parse or validate.

    The broken YAML is sent back through the regular copilot prompts. The
    user message describes the failure (``<ErrorClass>: <message>``) and the
    chat history is extended, on a copy, with the assistant reply that
    accompanied the broken YAML.

    Returns:
        The ``workflow_yaml`` field of the LLM's reply, or the original
        ``workflow_yaml`` when the reply has no such field.
    """
    failure_reason = f"{error.__class__.__name__}: {error}"

    # Copy-and-extend so the caller's history list is left untouched.
    extended_history = [
        *chat_history,
        WorkflowCopilotChatHistoryMessage(
            sender=WorkflowCopilotChatSender.AI,
            content=user_response,
            created_at=datetime.now(timezone.utc),
        ),
    ]

    knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")
    copilot_rules = app.AGENT_FUNCTION.get_copilot_security_rules()
    system_prompt = prompt_engine.load_prompt(
        template="workflow-copilot-system",
        workflow_knowledge_base=knowledge_base,
        current_datetime=datetime.now(timezone.utc).isoformat(),
        security_rules=copilot_rules,
    )
    repair_request = f"Workflow YAML parsing failed, please fix it: {failure_reason}"
    user_prompt = prompt_engine.load_prompt(
        template="workflow-copilot-user",
        workflow_yaml=escape_code_fences(workflow_yaml),
        user_message=escape_code_fences(repair_request),
        chat_history=escape_code_fences(_format_chat_history(extended_history)),
        global_llm_context=escape_code_fences(global_llm_context or ""),
        debug_run_info=escape_code_fences(debug_run_info_text),
    )

    call_started = time.monotonic()
    llm_response = await llm_api_handler(
        prompt=user_prompt,
        prompt_name="workflow-copilot",
        organization_id=organization_id,
        system_prompt=system_prompt,
    )
    LOG.info(
        "Auto-correction LLM response",
        duration_seconds=time.monotonic() - call_started,
        llm_response_len=len(llm_response),
        llm_response=llm_response,
    )

    reply = _parse_llm_response(llm_response)
    return reply.get("workflow_yaml", workflow_yaml)
def _collect_reachable(
    start_label: str,
    label_to_block: dict[str, BlockYAML],
    reachable: set[str],
) -> None:
    """Add to ``reachable`` every label that can be reached from ``start_label``.

    Edges followed are each block's ``next_block_label`` and, for conditional
    blocks, every branch's ``next_block_label``. Labels that are empty,
    unknown to ``label_to_block``, or already present in ``reachable`` are
    neither added nor expanded, so a pre-populated ``reachable`` acts as a
    "treat as already explored" set and cycles terminate.

    Mutates ``reachable`` in place.
    """
    # Explicit worklist instead of recursion; the resulting set is the same
    # regardless of traversal order.
    pending: list[str | None] = [start_label]
    while pending:
        label = pending.pop()
        if not label or label not in label_to_block or label in reachable:
            continue
        reachable.add(label)
        node = label_to_block[label]
        if isinstance(node, ConditionalBlockYAML):
            pending.extend(branch.next_block_label for branch in node.branch_conditions)
        pending.append(node.next_block_label)
def _break_cycles(
    start_label: str,
    label_to_block: dict[str, BlockYAML],
) -> bool:
    """Break circular ``next_block_label`` references reachable from ``start_label``.

    Depth-first walk that tracks the labels on the current path. An edge
    whose target is on the current path is a back-edge (a real cycle): it is
    logged and the owning object's ``next_block_label`` is set to ``None``.
    An edge to a label that was fully explored earlier is a merge point (two
    branches converging) and is left alone.

    For conditional blocks the branch edges are examined before the block's
    own ``next_block_label``. Only this level of blocks is inspected;
    ``ForLoopBlockYAML.loop_blocks`` are not descended into here.

    Returns:
        True if at least one edge was cut.
    """
    seen: set[str] = set()
    on_path: set[str] = set()
    cut_any = False

    def _explore(label: str) -> None:
        nonlocal cut_any
        seen.add(label)
        on_path.add(label)
        node = label_to_block[label]

        # Objects owning an outgoing edge: branches first, then the block itself.
        if isinstance(node, ConditionalBlockYAML):
            edge_owners = [*node.branch_conditions, node]
        else:
            edge_owners = [node]

        for owner in edge_owners:
            target = owner.next_block_label
            if not target or target not in label_to_block:
                continue
            if target in on_path:
                # Back-edge: cut it at its owner.
                LOG.warning(
                    "Copilot produced circular block chain, breaking cycle",
                    cycle_target=target,
                    broken_at=label,
                    is_branch_condition=hasattr(owner, "criteria"),
                    branch_expression=getattr(getattr(owner, "criteria", None), "expression", None),
                )
                owner.next_block_label = None
                cut_any = True
            elif target not in seen:
                _explore(target)
            # else: merge point that was already explored, not a cycle.

        on_path.discard(label)

    if start_label in label_to_block:
        _explore(start_label)
    return cut_any
def _find_terminal_label(
start_label: str,
label_to_block: dict[str, BlockYAML],
all_labels: set[str],
) -> str | None:
"""Find the terminal block by walking the main chain from start_label."""
visited: set[str] = set()
current: str | None = start_label
while current and current in label_to_block and current not in visited:
visited.add(current)
block = label_to_block[current]
if block.next_block_label is None or block.next_block_label not in all_labels:
return current
current = block.next_block_label
return None
def _order_orphaned_blocks(
orphaned_labels: set[str],
label_to_block: dict[str, BlockYAML],
all_labels: set[str],
blocks: list[BlockYAML],
) -> list[str]:
"""Order orphaned blocks by following their internal next_block_label chains.
Multiple disconnected orphan sub-chains are concatenated in the order their
chain-start appears in the original blocks list.
"""
pointed_to: set[str] = set()
for label in orphaned_labels:
block = label_to_block[label]
if block.next_block_label and block.next_block_label in orphaned_labels:
pointed_to.add(block.next_block_label)
# Chain starts are orphans not pointed to by another orphan.
# Preserve original array order for deterministic stitching.
chain_starts = [b.label for b in blocks if b.label in orphaned_labels and b.label not in pointed_to]
# If all orphans point to each other (cycle), pick the first in array order.
if not chain_starts:
chain_starts = [next(b.label for b in blocks if b.label in orphaned_labels)]
ordered: list[str] = []
visited: set[str] = set()
for start in chain_starts:
current: str | None = start
while current and current in orphaned_labels and current not in visited:
visited.add(current)
ordered.append(current)
current = label_to_block[current].next_block_label
# Append any remaining orphans not reached (multiple cycles).
for block in blocks:
if block.label in orphaned_labels and block.label not in visited:
ordered.append(block.label)
# Re-link the orphan chain so it forms a single connected path.
# This may overwrite an orphan's original next_block_label that pointed to a
# reachable block (a merge/join pattern). Log when this happens.
for i in range(len(ordered) - 1):
old_target = label_to_block[ordered[i]].next_block_label
new_target = ordered[i + 1]
if old_target and old_target != new_target and old_target not in orphaned_labels:
LOG.info(
"Orphan re-link overwrites cross-chain reference",
block=ordered[i],
old_target=old_target,
new_target=new_target,
)
label_to_block[ordered[i]].next_block_label = new_target
if ordered:
old_last_target = label_to_block[ordered[-1]].next_block_label
if old_last_target and old_last_target not in orphaned_labels:
LOG.info(
"Orphan chain terminal overwrites cross-chain reference",
block=ordered[-1],
old_target=old_last_target,
)
label_to_block[ordered[-1]].next_block_label = None
return ordered
def _repair_next_block_label_chain(blocks: list[BlockYAML]) -> None:
"""Ensure all top-level blocks form a single acyclic chain from blocks[0].
Repairs two classes of LLM mistakes:
1. Circular references — breaks cycles so the chain has a proper terminal block.
2. Disconnected paths — stitches orphaned blocks onto the end of the reachable chain.
Recursively repairs nested ForLoopBlockYAML.loop_blocks at all depths.
Mutates *blocks* in place.
"""
if len(blocks) <= 1:
# Still recurse into loop_blocks even for single-block lists
for block in blocks:
if isinstance(block, ForLoopBlockYAML) and block.loop_blocks:
_repair_next_block_label_chain(block.loop_blocks)
return
# Warn on duplicate labels — the dict comprehension silently keeps the last
# occurrence, so earlier blocks with the same label become invisible.
seen_labels: set[str] = set()
for block in blocks:
if block.label in seen_labels:
LOG.warning("Copilot produced duplicate block label", label=block.label)
seen_labels.add(block.label)
label_to_block: dict[str, BlockYAML] = {block.label: block for block in blocks}
all_labels = set(label_to_block.keys())
# Phase 1: break any circular references reachable from the first block.
# Note: cycles among orphaned blocks (unreachable from blocks[0]) are handled
# implicitly by _order_orphaned_blocks via its visited set and re-linking logic.
_break_cycles(blocks[0].label, label_to_block)
# Phase 2: find orphaned (unreachable) blocks and stitch them to the end.
reachable: set[str] = set()
_collect_reachable(blocks[0].label, label_to_block, reachable)
orphaned_labels = all_labels - reachable
if orphaned_labels:
LOG.warning(
"Copilot produced disconnected workflow blocks, repairing chain",
orphaned_labels=sorted(orphaned_labels),
reachable_labels=sorted(reachable),
)
terminal_label = _find_terminal_label(blocks[0].label, label_to_block, all_labels)
ordered_orphan_labels = _order_orphaned_blocks(orphaned_labels, label_to_block, all_labels, blocks)
if terminal_label and ordered_orphan_labels:
label_to_block[terminal_label].next_block_label = ordered_orphan_labels[0]
# Phase 3: recursively repair nested ForLoopBlockYAML.loop_blocks.
for block in blocks:
if isinstance(block, ForLoopBlockYAML) and block.loop_blocks:
_repair_next_block_label_chain(block.loop_blocks)
def _process_workflow_yaml(
    workflow_id: str,
    workflow_permanent_id: str,
    organization_id: str,
    workflow_yaml: str,
) -> Workflow:
    """Parse copilot-produced workflow YAML into a ``Workflow`` object.

    Steps:

    1. Parse the YAML with ``safe_load_no_dates``.
    2. Pre-validation fix-up: give every top-level block dict a ``title``
       (empty string) when the key is missing.
    3. Validate with ``WorkflowCreateYAMLRequest``.
    4. Post-processing: default the navigation goal of top-level login
       blocks, drop OUTPUT parameters, and repair the block chain.
    5. Convert the definition and wrap it in a ``Workflow`` (``version=1``,
       ``created_at``/``modified_at`` set to the current UTC time).

    The fix-up in step 2 is skipped for anything that is not shaped as
    expected (empty document, non-mapping root, non-list ``blocks``,
    non-dict block). Such input then fails step 3 with a
    ``ValidationError``, which callers already handle, rather than with an
    ``AttributeError``/``TypeError`` raised from the fix-up itself.

    Raises:
        ValidationError: when the parsed YAML does not match the request schema.
        Exception: whatever ``safe_load_no_dates`` or
            ``convert_workflow_definition`` raise is propagated unchanged.
    """
    parsed_yaml = safe_load_no_dates(workflow_yaml)

    # Fixing trivial common LLM mistakes (only on well-shaped input; see docstring).
    if isinstance(parsed_yaml, dict):
        raw_definition = parsed_yaml.get("workflow_definition")
        if isinstance(raw_definition, dict):
            raw_blocks = raw_definition.get("blocks")
            if isinstance(raw_blocks, list):
                for raw_block in raw_blocks:
                    if isinstance(raw_block, dict):
                        raw_block.setdefault("title", "")

    workflow_yaml_request = WorkflowCreateYAMLRequest.model_validate(parsed_yaml)

    # Post-processing
    definition = workflow_yaml_request.workflow_definition
    for block in definition.blocks:
        if isinstance(block, LoginBlockYAML) and not block.navigation_goal:
            block.navigation_goal = DEFAULT_LOGIN_PROMPT

    definition.parameters = [p for p in definition.parameters if p.parameter_type != ParameterType.OUTPUT]

    _repair_next_block_label_chain(definition.blocks)

    updated_workflow_definition = convert_workflow_definition(
        workflow_definition_yaml=definition,
        workflow_id=workflow_id,
    )

    now = datetime.now(timezone.utc)
    return Workflow(
        workflow_id=workflow_id,
        organization_id=organization_id,
        title=workflow_yaml_request.title or "",
        workflow_permanent_id=workflow_permanent_id,
        version=1,
        is_saved_task=workflow_yaml_request.is_saved_task,
        description=workflow_yaml_request.description,
        workflow_definition=updated_workflow_definition,
        proxy_location=workflow_yaml_request.proxy_location,
        webhook_callback_url=workflow_yaml_request.webhook_callback_url,
        persist_browser_session=workflow_yaml_request.persist_browser_session or False,
        model=workflow_yaml_request.model,
        max_screenshot_scrolls=workflow_yaml_request.max_screenshot_scrolls,
        extra_http_headers=workflow_yaml_request.extra_http_headers,
        run_with=workflow_yaml_request.run_with,
        ai_fallback=workflow_yaml_request.ai_fallback,
        cache_key=workflow_yaml_request.cache_key,
        run_sequentially=workflow_yaml_request.run_sequentially,
        sequential_key=workflow_yaml_request.sequential_key,
        created_at=now,
        modified_at=now,
    )
async def _new_copilot_chat_post(
    request: Request,
    chat_request: WorkflowCopilotChatRequest,
    organization: Organization,
) -> EventSourceResponse:
    """ENABLE_WORKFLOW_COPILOT_V2 dispatch target.

    Runs the copilot agent (``run_copilot_agent``) and streams progress,
    the final response, or an error to the client over SSE.

    Flow of the inner ``stream_handler``:

    1. Load the chat (or create one) and its messages; pick the most recent
       non-null ``global_llm_context``.
    2. If the chat holds a proposed workflow with ``_copilot_yaml``, use that
       YAML instead of the one sent in the request.
    3. Build the debug-run text from the run's blocks and HTML artifact.
    4. Load the current workflow as ``original_workflow`` and run the agent.
    5. Apply the rollback rule (``_should_restore_persisted_workflow``),
       update the chat's ``proposed_workflow``, persist the user and
       assistant messages, and send the RESPONSE event.

    Every ``except`` branch applies the same rollback rule before reporting.

    NOTE(review): the rollback rule reads flags from ``agent_result``, which
    is still ``None`` if ``run_copilot_agent`` itself raises. In that case no
    handler here restores anything - confirm the agent cleans up its own
    mid-run writes on failure.
    """

    async def stream_handler(stream: EventSourceStream) -> None:
        LOG.info(
            "Workflow copilot v2 chat request",
            workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
            workflow_run_id=chat_request.workflow_run_id,
            message=chat_request.message,
            workflow_yaml_length=len(chat_request.workflow_yaml),
            organization_id=organization.organization_id,
        )
        # Initialised before the try so the except handlers can always
        # reference them, whichever step failed.
        original_workflow: Workflow | None = None
        chat = None
        agent_result: Any = None
        try:
            await stream.send(
                WorkflowCopilotProcessingUpdate(
                    type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                    status="Processing...",
                    timestamp=datetime.now(timezone.utc),
                )
            )
            # Resume the referenced chat (validating it belongs to this
            # workflow) or start a new one.
            if chat_request.workflow_copilot_chat_id:
                chat = await app.DATABASE.workflow_params.get_workflow_copilot_chat_by_id(
                    organization_id=organization.organization_id,
                    workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
                )
                if not chat:
                    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
                if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
            else:
                chat = await app.DATABASE.workflow_params.create_workflow_copilot_chat(
                    organization_id=organization.organization_id,
                    workflow_permanent_id=chat_request.workflow_permanent_id,
                )
                chat_request.workflow_copilot_chat_id = chat.workflow_copilot_chat_id
            chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
            )
            # Most recent message that carries a context wins.
            global_llm_context = None
            for message in reversed(chat_messages):
                if message.global_llm_context is not None:
                    global_llm_context = message.global_llm_context
                    break
            # A pending proposal's YAML replaces the YAML sent by the client.
            if chat.proposed_workflow and chat.proposed_workflow.get("_copilot_yaml"):
                chat_request.workflow_yaml = chat.proposed_workflow["_copilot_yaml"]
            block_infos, debug_html = await _get_new_copilot_block_infos(
                organization.organization_id, chat_request.workflow_run_id
            )
            # One text entry per run block, followed by the HTML when present.
            debug_run_info_text = ""
            if block_infos:
                parts: list[str] = []
                for bi in block_infos:
                    block_text = f"Block: {bi.block_label} ({bi.block_type}) — {bi.block_status}"
                    if bi.failure_reason:
                        block_text += f"\n Failure Reason: {bi.failure_reason}"
                    if bi.output:
                        block_text += f"\n Output: {bi.output}"
                    parts.append(block_text)
                debug_run_info_text = "\n".join(parts)
                if debug_html:
                    debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_html}"
            await stream.send(
                WorkflowCopilotProcessingUpdate(
                    type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                    status="Thinking...",
                    timestamp=datetime.now(timezone.utc),
                )
            )
            # No early exit on disconnect (SKY-8986): the agent runs to
            # completion even after the SSE stream drops so its reply is
            # persisted to the chat history and visible after reconnect.
            original_workflow = await app.DATABASE.workflows.get_workflow_by_permanent_id(
                workflow_permanent_id=chat_request.workflow_permanent_id,
                organization_id=organization.organization_id,
            )
            if not original_workflow:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found")
            # The stored workflow's id overrides whatever id the client sent.
            chat_request.workflow_id = original_workflow.workflow_id
            llm_api_handler = (
                await get_llm_handler_for_prompt_type(
                    "workflow-copilot", chat_request.workflow_permanent_id, organization.organization_id
                )
                or app.LLM_API_HANDLER
            )
            # The caller's API key header is forwarded to the agent.
            api_key = request.headers.get("x-api-key")
            security_rules = app.AGENT_FUNCTION.get_copilot_security_rules()
            agent_result = await run_copilot_agent(
                stream=stream,
                organization_id=organization.organization_id,
                chat_request=chat_request,
                chat_history=convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
                global_llm_context=global_llm_context,
                debug_run_info_text=debug_run_info_text,
                llm_api_handler=llm_api_handler,
                api_key=api_key,
                security_rules=security_rules,
            )
            user_response = agent_result.user_response
            updated_workflow = agent_result.updated_workflow
            updated_global_llm_context = agent_result.global_llm_context
            # Persist rollback / proposed-workflow state and the chat
            # messages regardless of whether the SSE client is still
            # connected: the user needs to see the reply on reconnect.
            # SKY-8986: client disconnect used to short-circuit this block
            # and leave the chat history without the AI response.
            #
            # SKY-9143: restore runs outside the auto_accept wrapper so
            # auto-accept turns that ended without a viable proposal still
            # roll back a mid-turn _update_workflow write. The Accept/Reject
            # panel state below stays gated on auto_accept — the frontend
            # applies proposals via applyWorkflowUpdate when auto-accept is
            # on.
            restored = _should_restore_persisted_workflow(chat.auto_accept, agent_result)
            if restored:
                await _restore_workflow_definition(original_workflow, organization.organization_id)
            if chat.auto_accept is not True:
                if updated_workflow:
                    # Store the proposal (plus the agent's YAML, when it
                    # provided one) on the chat.
                    proposed_data = updated_workflow.model_dump(mode="json")
                    if agent_result.workflow_yaml:
                        proposed_data["_copilot_yaml"] = agent_result.workflow_yaml
                    await app.DATABASE.workflow_params.update_workflow_copilot_chat(
                        organization_id=chat.organization_id,
                        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                        proposed_workflow=proposed_data,
                    )
                elif (
                    restored or getattr(agent_result, "clear_proposed_workflow", False)
                ) and chat.proposed_workflow is not None:
                    # Null any previously-persisted proposed_workflow so a
                    # page reload does not resurrect a stale Accept/Reject
                    # card next to an assistant message that just explained
                    # why no verified proposal is available. Covers:
                    # * feasibility-gate fast-path clarifications, and
                    # * SKY-9143 strict-gate turns where a mid-turn draft was
                    # rolled back (``restored=True``).
                    await app.DATABASE.workflow_params.update_workflow_copilot_chat(
                        organization_id=chat.organization_id,
                        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                        proposed_workflow=None,
                    )
            # Both messages are written only after the agent has finished.
            await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
                organization_id=chat.organization_id,
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                sender=WorkflowCopilotChatSender.USER,
                content=chat_request.message,
            )
            assistant_message = await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
                organization_id=chat.organization_id,
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                sender=WorkflowCopilotChatSender.AI,
                content=user_response,
                global_llm_context=updated_global_llm_context,
            )
            await stream.send(
                WorkflowCopilotStreamResponseUpdate(
                    type=WorkflowCopilotStreamMessageType.RESPONSE,
                    workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                    message=user_response,
                    updated_workflow=updated_workflow.model_dump(mode="json") if updated_workflow else None,
                    response_time=assistant_message.created_at,
                    total_tokens=getattr(agent_result, "total_tokens", None),
                    response_type=getattr(agent_result, "response_type", "REPLY"),
                )
            )
        except HTTPException as exc:
            # Expected failures: surface the HTTP detail to the client as-is.
            if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
                await _restore_workflow_definition(original_workflow, organization.organization_id)
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error=exc.detail,
                )
            )
        except LLMProviderError as exc:
            # Provider failures are logged in full; the client gets a generic message.
            if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
                await _restore_workflow_definition(original_workflow, organization.organization_id)
            LOG.error(
                "LLM provider error (copilot v2)",
                organization_id=organization.organization_id,
                error=str(exc),
                exc_info=True,
            )
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error="Failed to process your request. Please try again.",
                )
            )
        except asyncio.CancelledError:
            # The restore is shielded so the pending cancellation does not
            # interrupt it. The cancellation is not re-raised here.
            if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
                await asyncio.shield(_restore_workflow_definition(original_workflow, organization.organization_id))
            LOG.info(
                "Client disconnected during workflow copilot v2",
                workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
            )
        except Exception as exc:
            # Catch-all boundary: log with traceback, report a generic error.
            if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
                await _restore_workflow_definition(original_workflow, organization.organization_id)
            LOG.error(
                "Unexpected error in workflow copilot v2",
                organization_id=organization.organization_id,
                error=str(exc),
                exc_info=True,
            )
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error="An error occurred. Please try again.",
                )
            )

    return FastAPIEventSourceStream.create(request, stream_handler)
COPILOT_V2_FLAG_KEY = "ENABLE_WORKFLOW_COPILOT_V2"
async def _should_use_copilot_v2(organization: Organization, workflow_permanent_id: str) -> bool:
    """Decide whether a copilot chat request should be served by copilot v2.

    The ``settings.ENABLE_WORKFLOW_COPILOT_V2`` switch wins outright. When it is
    off, the ``COPILOT_V2_FLAG_KEY`` feature flag is evaluated for the
    organization. Any exception raised while evaluating the flag is logged and
    treated as "disabled", so the caller falls back to the legacy copilot.

    Args:
        organization: Organization issuing the request; its id is used as the
            flag's distinct id and as a flag property.
        workflow_permanent_id: Only used as context in the failure log entry.

    Returns:
        True when copilot v2 should handle the request, otherwise False.
    """
    if not settings.ENABLE_WORKFLOW_COPILOT_V2:
        try:
            # distinct_id is the org (not a run id) because this gate is an org-sticky
            # rollout: copilot chat may not have a stable run at dispatch time, and each
            # org should see the same path across sessions. Contrast with backend.md's
            # default of run-level ids for per-run randomized experiments.
            org_id = organization.organization_id
            flag_enabled = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
                COPILOT_V2_FLAG_KEY,
                distinct_id=org_id,
                properties={"organization_id": org_id},
            )
        except Exception:
            LOG.exception(
                "Failed to evaluate copilot-v2 feature flag; falling back to legacy copilot",
                organization_id=organization.organization_id,
                workflow_permanent_id=workflow_permanent_id,
            )
            return False
        return flag_enabled
    return True
@base_router.post("/workflow/copilot/chat-post", include_in_schema=False)
async def workflow_copilot_chat_post(
    request: Request,
    chat_request: WorkflowCopilotChatRequest,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> EventSourceResponse:
    """Stream a copilot reply for a workflow chat message as server-sent events.

    When ``_should_use_copilot_v2`` returns True the request is delegated to
    ``_new_copilot_chat_post``. Otherwise the legacy flow runs inside the stream
    handler: resolve (or create) the chat, load its messages, call the LLM,
    persist the proposed workflow and both chat messages, and finally emit the
    response event. Failures are reported to the client as an error event
    instead of being raised out of the handler.

    Args:
        request: Incoming HTTP request, handed to the event-source stream factory.
        chat_request: Chat payload (message, workflow YAML, optional chat/run ids).
        organization: Organization resolved by the auth dependency.

    Returns:
        The event-source response produced by ``FastAPIEventSourceStream.create``
        (or by the v2 handler when the request is delegated).
    """
    if await _should_use_copilot_v2(organization, chat_request.workflow_permanent_id):
        return await _new_copilot_chat_post(request, chat_request, organization)

    async def stream_handler(stream: EventSourceStream) -> None:
        org_id = organization.organization_id

        async def _emit_status(text: str) -> None:
            # Progress event shown to the client while work is in flight.
            await stream.send(
                WorkflowCopilotProcessingUpdate(
                    type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                    status=text,
                    timestamp=datetime.now(timezone.utc),
                )
            )

        async def _emit_error(detail: Any) -> None:
            # Error event; the handler reports failures on the stream rather than raising.
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error=detail,
                )
            )

        LOG.info(
            "Workflow copilot chat request",
            workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
            workflow_run_id=chat_request.workflow_run_id,
            message=chat_request.message,
            workflow_yaml_length=len(chat_request.workflow_yaml),
            organization_id=org_id,
        )

        try:
            await _emit_status("Processing...")

            # Start a new chat when no id was supplied; otherwise load the existing
            # chat and make sure it belongs to the workflow named in the request.
            requested_chat_id = chat_request.workflow_copilot_chat_id
            if not requested_chat_id:
                chat = await app.DATABASE.workflow_params.create_workflow_copilot_chat(
                    organization_id=org_id,
                    workflow_permanent_id=chat_request.workflow_permanent_id,
                )
            else:
                chat = await app.DATABASE.workflow_params.get_workflow_copilot_chat_by_id(
                    organization_id=org_id,
                    workflow_copilot_chat_id=requested_chat_id,
                )
                if not chat:
                    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
                if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")

            chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
            )

            # Walk the messages from the end of the list and take the first
            # context that is set.
            global_llm_context = next(
                (
                    past_message.global_llm_context
                    for past_message in reversed(chat_messages)
                    if past_message.global_llm_context is not None
                ),
                None,
            )

            debug_run_info = await _get_debug_run_info(org_id, chat_request.workflow_run_id)

            # Format debug run info for prompt
            debug_parts: list[str] = []
            if debug_run_info:
                debug_parts.append(f"Block Label: {debug_run_info.block_label}")
                debug_parts.append(f" Block Type: {debug_run_info.block_type}")
                debug_parts.append(f" Status: {debug_run_info.block_status}")
                if debug_run_info.failure_reason:
                    debug_parts.append(f"\nFailure Reason: {debug_run_info.failure_reason}")
                if debug_run_info.html:
                    debug_parts.append(f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}")
            debug_run_info_text = "".join(debug_parts)

            await _emit_status("Thinking...")

            # SKY-8986: do not short-circuit on client disconnect. The LLM
            # call and the DB persistence below must complete so the reply
            # is in the chat history when the user reconnects.
            recent_history = convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:])
            user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm(
                stream,
                org_id,
                chat_request,
                recent_history,
                global_llm_context,
                debug_run_info_text,
            )

            # Store the proposal on the chat unless auto_accept is exactly True.
            if updated_workflow:
                if chat.auto_accept is not True:
                    await app.DATABASE.workflow_params.update_workflow_copilot_chat(
                        organization_id=chat.organization_id,
                        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                        proposed_workflow=updated_workflow.model_dump(mode="json"),
                    )

            # Persist the user's message first, then the assistant's reply.
            await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
                organization_id=chat.organization_id,
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                sender=WorkflowCopilotChatSender.USER,
                content=chat_request.message,
            )
            assistant_message = await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
                organization_id=chat.organization_id,
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                sender=WorkflowCopilotChatSender.AI,
                content=user_response,
                global_llm_context=updated_global_llm_context,
            )

            await stream.send(
                WorkflowCopilotStreamResponseUpdate(
                    type=WorkflowCopilotStreamMessageType.RESPONSE,
                    workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                    message=user_response,
                    updated_workflow=updated_workflow.model_dump(mode="json") if updated_workflow else None,
                    response_time=assistant_message.created_at,
                )
            )
        except HTTPException as exc:
            # Validation failures raised above carry a client-facing detail.
            await _emit_error(exc.detail)
        except LLMProviderError as exc:
            LOG.error(
                "LLM provider error",
                organization_id=organization.organization_id,
                error=str(exc),
                exc_info=True,
            )
            await _emit_error("Failed to process your request. Please try again.")
        except Exception as exc:
            LOG.error(
                "Unexpected error in workflow copilot",
                organization_id=organization.organization_id,
                error=str(exc),
                exc_info=True,
            )
            await _emit_error("An error occurred. Please try again.")

    return FastAPIEventSourceStream.create(request, stream_handler)
@base_router.get("/workflow/copilot/chat-history", include_in_schema=False)
async def workflow_copilot_chat_history(
    workflow_permanent_id: str,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowCopilotChatHistoryResponse:
    """Return the copilot chat history for a workflow.

    Looks up the chat returned by ``get_latest_workflow_copilot_chat`` for the
    organization and workflow. When there is none, the response carries no chat
    id, an empty history, and no proposed workflow or auto-accept value.

    Args:
        workflow_permanent_id: Permanent id of the workflow whose chat is requested.
        organization: Organization resolved by the auth dependency.

    Returns:
        The chat id, converted message history, proposed workflow and
        auto-accept setting of the chat that was found.
    """
    latest_chat = await app.DATABASE.workflow_params.get_latest_workflow_copilot_chat(
        organization_id=organization.organization_id,
        workflow_permanent_id=workflow_permanent_id,
    )

    if not latest_chat:
        # No chat exists yet for this workflow: answer with an empty history.
        return WorkflowCopilotChatHistoryResponse(
            workflow_copilot_chat_id=None,
            chat_history=convert_to_history_messages([]),
            proposed_workflow=None,
            auto_accept=None,
        )

    stored_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
        latest_chat.workflow_copilot_chat_id
    )
    return WorkflowCopilotChatHistoryResponse(
        workflow_copilot_chat_id=latest_chat.workflow_copilot_chat_id,
        chat_history=convert_to_history_messages(stored_messages),
        proposed_workflow=latest_chat.proposed_workflow,
        auto_accept=latest_chat.auto_accept,
    )
@base_router.post(
    "/workflow/copilot/clear-proposed-workflow", include_in_schema=False, status_code=status.HTTP_204_NO_CONTENT
)
async def workflow_copilot_clear_proposed_workflow(
    clear_request: WorkflowCopilotClearProposedWorkflowRequest,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> None:
    """Update a copilot chat with no proposed workflow and the requested auto-accept value.

    Args:
        clear_request: Carries the chat id and the ``auto_accept`` value to store.
        organization: Organization resolved by the auth dependency.

    Raises:
        HTTPException: 404 when the update call returns no chat.
    """
    chat_after_update = await app.DATABASE.workflow_params.update_workflow_copilot_chat(
        organization_id=organization.organization_id,
        workflow_copilot_chat_id=clear_request.workflow_copilot_chat_id,
        proposed_workflow=None,
        auto_accept=clear_request.auto_accept,
    )
    if chat_after_update:
        return
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
def convert_to_history_messages(
    messages: list[WorkflowCopilotChatMessage],
) -> list[WorkflowCopilotChatHistoryMessage]:
    """Map stored chat messages to history messages.

    Only ``sender``, ``content`` and ``created_at`` are carried over; input
    order is preserved.

    Args:
        messages: Chat messages to convert.

    Returns:
        One history message per input message, in the same order.
    """
    history: list[WorkflowCopilotChatHistoryMessage] = []
    for stored in messages:
        history.append(
            WorkflowCopilotChatHistoryMessage(
                sender=stored.sender,
                content=stored.content,
                created_at=stored.created_at,
            )
        )
    return history
@base_router.post("/workflow/copilot/convert-yaml-to-blocks", include_in_schema=False)
async def workflow_copilot_convert_yaml_to_blocks(
    request: WorkflowYAMLConversionRequest,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowYAMLConversionResponse:
    """
    Convert workflow definition YAML to blocks format for comparison view.

    This endpoint is used by the frontend to convert YAML to the proper blocks structure
    that the comparison panel expects.

    Args:
        request: Carries the workflow definition YAML text and the workflow id
            passed to the converter.
        organization: Organization resolved by the auth dependency; not read
            in the body.

    Returns:
        A response wrapping the converted workflow definition dumped in JSON mode.

    Raises:
        HTTPException: 400 when the YAML cannot be parsed, fails model
            validation, or the conversion raises a workflow HTTP exception.
    """
    try:
        parsed_yaml = safe_load_no_dates(request.workflow_definition_yaml)
        workflow_definition_yaml = WorkflowDefinitionYAML.model_validate(parsed_yaml)
        # Called on the parsed blocks before conversion; see the helper for what it repairs.
        _repair_next_block_label_chain(workflow_definition_yaml.blocks)
        workflow_definition = convert_workflow_definition(
            workflow_definition_yaml=workflow_definition_yaml,
            workflow_id=request.workflow_id,
        )
        return WorkflowYAMLConversionResponse(workflow_definition=workflow_definition.model_dump(mode="json"))
    except (yaml.YAMLError, ValidationError, BaseWorkflowHTTPException) as e:
        # Chain explicitly so the original parse/validation error stays attached
        # as the cause of the 400.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Failed to convert workflow YAML: {str(e)}",
        ) from e