mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-26 10:41:14 +00:00
1224 lines
51 KiB
Python
1224 lines
51 KiB
Python
import asyncio
|
|
import time
|
|
from dataclasses import dataclass
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import structlog
|
|
import yaml
|
|
from fastapi import Depends, HTTPException, Request, status
|
|
from pydantic import ValidationError
|
|
from sse_starlette import EventSourceResponse
|
|
|
|
from skyvern.config import settings
|
|
from skyvern.constants import DEFAULT_LOGIN_PROMPT
|
|
from skyvern.forge import app
|
|
from skyvern.forge.prompts import prompt_engine
|
|
from skyvern.forge.sdk.api.llm.api_handler import LLMAPIHandler
|
|
from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError
|
|
from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType
|
|
from skyvern.forge.sdk.copilot.agent import run_copilot_agent
|
|
from skyvern.forge.sdk.copilot.output_utils import truncate_output
|
|
from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type
|
|
from skyvern.forge.sdk.routes.event_source_stream import EventSourceStream, FastAPIEventSourceStream
|
|
from skyvern.forge.sdk.routes.routers import base_router
|
|
from skyvern.forge.sdk.schemas.organizations import Organization
|
|
from skyvern.forge.sdk.schemas.workflow_copilot import (
|
|
WorkflowCopilotChatHistoryMessage,
|
|
WorkflowCopilotChatHistoryResponse,
|
|
WorkflowCopilotChatMessage,
|
|
WorkflowCopilotChatRequest,
|
|
WorkflowCopilotChatSender,
|
|
WorkflowCopilotClearProposedWorkflowRequest,
|
|
WorkflowCopilotProcessingUpdate,
|
|
WorkflowCopilotStreamErrorUpdate,
|
|
WorkflowCopilotStreamMessageType,
|
|
WorkflowCopilotStreamResponseUpdate,
|
|
WorkflowYAMLConversionRequest,
|
|
WorkflowYAMLConversionResponse,
|
|
)
|
|
from skyvern.forge.sdk.services import org_auth_service
|
|
from skyvern.forge.sdk.workflow.exceptions import BaseWorkflowHTTPException
|
|
from skyvern.forge.sdk.workflow.models.parameter import ParameterType
|
|
from skyvern.forge.sdk.workflow.models.workflow import Workflow
|
|
from skyvern.forge.sdk.workflow.workflow_definition_converter import convert_workflow_definition
|
|
from skyvern.schemas.workflows import (
|
|
BlockYAML,
|
|
BranchConditionYAML,
|
|
ConditionalBlockYAML,
|
|
ForLoopBlockYAML,
|
|
LoginBlockYAML,
|
|
WorkflowCreateYAMLRequest,
|
|
WorkflowDefinitionYAML,
|
|
)
|
|
from skyvern.utils.strings import escape_code_fences
|
|
from skyvern.utils.yaml_loader import safe_load_no_dates
|
|
|
|
# Knowledge-base text read and injected into the "workflow-copilot-system"
# prompt template by copilot_call_llm / _auto_correct_workflow_yaml.
WORKFLOW_KNOWLEDGE_BASE_PATH = Path("skyvern/forge/prompts/skyvern/workflow_knowledge_base.txt")
# NOTE(review): not referenced in this chunk — presumably limits how many
# prior chat messages are sent as LLM context further down the file; confirm.
CHAT_HISTORY_CONTEXT_MESSAGES = 10

LOG = structlog.get_logger()
|
|
|
|
|
|
@dataclass(frozen=True)
class RunInfo:
    """Single-block debug snapshot built by ``_get_debug_run_info``.

    Used by the old-copilot path: captures the first block of a workflow run
    plus the run's decoded VISIBLE_ELEMENTS_TREE artifact.
    """

    # Label of the run's first block, if it has one.
    block_label: str | None
    # Name of the block-type enum member.
    block_type: str
    # Recorded status of the block, if any.
    block_status: str | None
    # Failure explanation when the block failed.
    failure_reason: str | None
    # Decoded VISIBLE_ELEMENTS_TREE HTML artifact, when one exists.
    html: str | None
|
|
|
|
|
|
# New-copilot richer block shape (used only from the ENABLE_WORKFLOW_COPILOT_V2
# dispatch path). Kept side-by-side with the old RunInfo so the old-copilot
# body stays untouched; consolidation is SKY-8916's job.
@dataclass(frozen=True)
class BlockRunInfo:
    """Per-block debug record built by ``_get_new_copilot_block_infos``."""

    # Label of the block, if it has one.
    block_label: str | None
    # Block type name (enum member name, or str() of the raw value).
    block_type: str
    # Recorded status of the block, if any.
    block_status: str | None
    # Failure explanation when the block failed.
    failure_reason: str | None
    # Truncated textual form of the block's output (via truncate_output).
    output: str | None
|
|
|
|
|
|
def _should_restore_persisted_workflow(auto_accept: bool | None, agent_result: object | None) -> bool:
|
|
"""Return True when a persisted draft should be rolled back.
|
|
|
|
SKY-9143: when the agent decided not to ship a proposal this turn
|
|
(``updated_workflow is None``) but ``_update_workflow`` already committed
|
|
a YAML to ``workflow_definition``, we must restore the original even under
|
|
``auto_accept=True`` — otherwise an unverified edit becomes the live
|
|
workflow silently.
|
|
"""
|
|
if not bool(getattr(agent_result, "workflow_was_persisted", False)):
|
|
return False
|
|
if getattr(agent_result, "updated_workflow", None) is None:
|
|
return True
|
|
return auto_accept is not True
|
|
|
|
|
|
async def _restore_workflow_definition(original_workflow: Workflow | None, organization_id: str) -> None:
    """Unconditionally write ``original_workflow`` back as the live definition.

    Callers must gate this with ``_should_restore_persisted_workflow`` so the
    success, disconnect, and exception paths share a single rollback rule:
    restore only when the user did not opt into auto-accept AND the agent
    loop actually persisted a mid-request draft. Best-effort: failures are
    logged and swallowed, never raised.
    """
    if not original_workflow:
        return

    try:
        await app.WORKFLOW_SERVICE.update_workflow_definition(
            workflow_id=original_workflow.workflow_id,
            organization_id=organization_id,
            title=original_workflow.title,
            description=original_workflow.description,
            workflow_definition=original_workflow.workflow_definition,
        )
    except Exception:
        # Deliberately broad: rollback is best-effort cleanup and must not
        # mask whatever error path triggered it.
        LOG.warning(
            "Failed to restore original workflow",
            workflow_id=original_workflow.workflow_id,
            exc_info=True,
        )
|
|
|
|
|
|
async def _get_debug_artifact(organization_id: str, workflow_run_id: str) -> Artifact | None:
    """Fetch the run's first VISIBLE_ELEMENTS_TREE artifact, or None."""
    artifacts = await app.DATABASE.artifacts.get_artifacts_for_run(
        run_id=workflow_run_id,
        organization_id=organization_id,
        artifact_types=[ArtifactType.VISIBLE_ELEMENTS_TREE],
    )
    # Guard against a non-list / empty result from the artifact store.
    if isinstance(artifacts, list) and artifacts:
        return artifacts[0]
    return None
|
|
|
|
|
|
async def _get_debug_run_info(organization_id: str, workflow_run_id: str | None) -> RunInfo | None:
    """Build a single-block ``RunInfo`` snapshot for the old-copilot path.

    Returns None when there is no run id or the run has no blocks. The HTML
    field carries the run's decoded VISIBLE_ELEMENTS_TREE artifact when one
    exists.
    """
    if not workflow_run_id:
        return None

    run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
        workflow_run_id=workflow_run_id, organization_id=organization_id
    )
    if not run_blocks:
        return None

    # Only the first block is surfaced in this (old-copilot) shape.
    first_block = run_blocks[0]

    html: str | None = None
    artifact = await _get_debug_artifact(organization_id, workflow_run_id)
    if artifact:
        raw_bytes = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
        if raw_bytes:
            html = raw_bytes.decode("utf-8")

    return RunInfo(
        block_label=first_block.label,
        block_type=first_block.block_type.name,
        block_status=first_block.status,
        failure_reason=first_block.failure_reason,
        html=html,
    )
|
|
|
|
|
|
async def _get_new_copilot_block_infos(
    organization_id: str, workflow_run_id: str | None
) -> tuple[list[BlockRunInfo], str | None]:
    """Variant of _get_debug_run_info used by the ENABLE_WORKFLOW_COPILOT_V2 path.

    Returns a list of per-block records plus the run's VISIBLE_ELEMENTS_TREE
    HTML artifact. Coexists with _get_debug_run_info which returns the
    simpler single-block shape used by the old-copilot path.
    """
    if not workflow_run_id:
        return [], None

    run_blocks = await app.DATABASE.observer.get_workflow_run_blocks(
        workflow_run_id=workflow_run_id, organization_id=organization_id
    )
    if not run_blocks:
        return [], None

    def _type_name(raw_type: Any) -> str:
        # block_type may be an enum or a plain value; prefer the enum name.
        return raw_type.name if hasattr(raw_type, "name") else str(raw_type)

    block_infos = [
        BlockRunInfo(
            block_label=run_block.label,
            block_type=_type_name(run_block.block_type),
            block_status=run_block.status,
            failure_reason=run_block.failure_reason,
            output=truncate_output(getattr(run_block, "output", None)),
        )
        for run_block in run_blocks
    ]

    html: str | None = None
    artifact = await _get_debug_artifact(organization_id, workflow_run_id)
    if artifact:
        raw_bytes = await app.ARTIFACT_MANAGER.retrieve_artifact(artifact)
        if raw_bytes:
            html = raw_bytes.decode("utf-8")

    return block_infos, html
|
|
|
|
|
|
def _format_chat_history(chat_history: list[WorkflowCopilotChatHistoryMessage]) -> str:
    """Render chat history as newline-separated ``<sender>: <content>`` lines.

    Returns an empty string for an empty history.
    """
    if not chat_history:
        return ""
    return "\n".join(f"{message.sender}: {message.content}" for message in chat_history)
|
|
|
|
|
|
def _parse_llm_response(llm_response: dict[str, Any] | Any) -> Any:
|
|
if isinstance(llm_response, dict) and "output" in llm_response:
|
|
action_data = llm_response["output"]
|
|
else:
|
|
action_data = llm_response
|
|
|
|
if not isinstance(action_data, dict):
|
|
LOG.error(
|
|
"LLM response is not valid JSON",
|
|
response_type=type(action_data).__name__,
|
|
)
|
|
raise HTTPException(
|
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
detail="Invalid response from LLM",
|
|
)
|
|
return action_data
|
|
|
|
|
|
async def copilot_call_llm(
    stream: EventSourceStream,
    organization_id: str,
    chat_request: WorkflowCopilotChatRequest,
    chat_history: list[WorkflowCopilotChatHistoryMessage],
    global_llm_context: str | None,
    debug_run_info_text: str,
) -> tuple[str, Workflow | None, str | None]:
    """Run one old-copilot LLM turn and dispatch on the action it returns.

    Renders the system/user prompts, calls the configured LLM handler, and
    interprets the response's ``type`` field:

    * ``REPLACE_WORKFLOW`` — parse the proposed YAML into a ``Workflow``; on
      parse/validation failure, attempt one auto-correction round via
      ``_auto_correct_workflow_yaml`` before re-parsing (a second failure
      propagates to the caller).
    * ``REPLY`` / ``ASK_QUESTION`` — return the response text only.
    * anything else — log an error and return a generic fallback message.

    Args:
        stream: SSE stream used to emit processing updates to the client.
        organization_id: Organization the chat belongs to.
        chat_request: Incoming message plus the current workflow YAML/ids.
        chat_history: Prior messages, rendered into the user prompt.
        global_llm_context: Sticky context carried across turns, if any.
        debug_run_info_text: Pre-rendered debug-run context for the prompt.

    Returns:
        ``(user_response_text, updated_workflow_or_None,
        new_global_llm_context_or_None)``.
    """
    chat_history_text = _format_chat_history(chat_history)

    workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")

    # Render system prompt (trusted content only, security rules injected via AgentFunction)
    security_rules = app.AGENT_FUNCTION.get_copilot_security_rules()
    system_prompt = prompt_engine.load_prompt(
        template="workflow-copilot-system",
        workflow_knowledge_base=workflow_knowledge_base,
        current_datetime=datetime.now(timezone.utc).isoformat(),
        security_rules=security_rules,
    )

    # Render user prompt (untrusted content, each variable in code fences)
    # Escape triple backticks to prevent code fence breakout
    user_prompt = prompt_engine.load_prompt(
        template="workflow-copilot-user",
        workflow_yaml=escape_code_fences(chat_request.workflow_yaml or ""),
        user_message=escape_code_fences(chat_request.message),
        chat_history=escape_code_fences(chat_history_text),
        global_llm_context=escape_code_fences(global_llm_context or ""),
        debug_run_info=escape_code_fences(debug_run_info_text),
    )

    LOG.info(
        "Calling LLM",
        workflow_permanent_id=chat_request.workflow_permanent_id,
        workflow_id=chat_request.workflow_id,
        user_message_len=len(chat_request.message),
        user_message=chat_request.message,
        workflow_yaml_len=len(chat_request.workflow_yaml or ""),
        workflow_yaml=chat_request.workflow_yaml or "",
        chat_history_len=len(chat_history_text),
        chat_history=chat_history_text,
        global_llm_context_len=len(global_llm_context or ""),
        global_llm_context=global_llm_context or "",
        workflow_knowledge_base_len=len(workflow_knowledge_base),
        debug_run_info_len=len(debug_run_info_text),
        system_prompt_len=len(system_prompt),
        user_prompt_len=len(user_prompt),
    )
    # Per-prompt-type experiment override; falls back to the default handler.
    llm_api_handler = (
        await get_llm_handler_for_prompt_type("workflow-copilot", chat_request.workflow_permanent_id, organization_id)
        or app.LLM_API_HANDLER
    )
    llm_start_time = time.monotonic()
    llm_response = await llm_api_handler(
        prompt=user_prompt,
        prompt_name="workflow-copilot",
        organization_id=organization_id,
        system_prompt=system_prompt,
    )
    LOG.info(
        "LLM response",
        workflow_permanent_id=chat_request.workflow_permanent_id,
        workflow_id=chat_request.workflow_id,
        duration_seconds=time.monotonic() - llm_start_time,
        user_message_len=len(chat_request.message),
        workflow_yaml_len=len(chat_request.workflow_yaml or ""),
        chat_history_len=len(chat_history_text),
        global_llm_context_len=len(global_llm_context or ""),
        debug_run_info_len=len(debug_run_info_text),
        workflow_knowledge_base_len=len(workflow_knowledge_base),
        llm_response_len=len(llm_response),
        llm_response=llm_response,
    )

    action_data = _parse_llm_response(llm_response)

    action_type = action_data.get("type")
    user_response_value = action_data.get("user_response")
    if user_response_value is None:
        # Fallback text when the LLM omitted a user-facing response.
        user_response = "I received your request but I'm not sure how to help. Could you rephrase?"
    else:
        user_response = str(user_response_value)
    LOG.info(
        "LLM response received",
        workflow_permanent_id=chat_request.workflow_permanent_id,
        workflow_id=chat_request.workflow_id,
        organization_id=organization_id,
        action_type=action_type,
    )

    # Replace the sticky context only when the LLM returned a new one.
    global_llm_context = action_data.get("global_llm_context")
    if global_llm_context is not None:
        global_llm_context = str(global_llm_context)

    if action_type == "REPLACE_WORKFLOW":
        llm_workflow_yaml = action_data.get("workflow_yaml", "")
        try:
            updated_workflow = _process_workflow_yaml(
                workflow_id=chat_request.workflow_id,
                workflow_permanent_id=chat_request.workflow_permanent_id,
                organization_id=organization_id,
                workflow_yaml=llm_workflow_yaml,
            )
        except (yaml.YAMLError, ValidationError, BaseWorkflowHTTPException) as e:
            # One-shot repair loop: surface a processing update, ask the LLM
            # to fix its own YAML, then re-parse. A second parse failure here
            # propagates to the caller uncaught.
            await stream.send(
                WorkflowCopilotProcessingUpdate(
                    type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                    status="Validating workflow definition...",
                    timestamp=datetime.now(timezone.utc),
                )
            )
            corrected_workflow_yaml = await _auto_correct_workflow_yaml(
                llm_api_handler=llm_api_handler,
                organization_id=organization_id,
                user_response=user_response,
                workflow_yaml=llm_workflow_yaml,
                chat_history=chat_history,
                global_llm_context=global_llm_context,
                debug_run_info_text=debug_run_info_text,
                error=e,
            )
            updated_workflow = _process_workflow_yaml(
                workflow_id=chat_request.workflow_id,
                workflow_permanent_id=chat_request.workflow_permanent_id,
                organization_id=organization_id,
                workflow_yaml=corrected_workflow_yaml,
            )

        return user_response, updated_workflow, global_llm_context
    elif action_type == "REPLY":
        return user_response, None, global_llm_context
    elif action_type == "ASK_QUESTION":
        return user_response, None, global_llm_context
    else:
        LOG.error(
            "Unknown action type from LLM",
            organization_id=organization_id,
            action_type=action_type,
        )
        return "I received your request but I'm not sure how to help. Could you rephrase?", None, None
|
|
|
|
|
|
async def _auto_correct_workflow_yaml(
    llm_api_handler: LLMAPIHandler,
    organization_id: str,
    user_response: str,
    workflow_yaml: str,
    chat_history: list[WorkflowCopilotChatHistoryMessage],
    global_llm_context: str | None,
    debug_run_info_text: str,
    error: Exception,
) -> str:
    """Ask the LLM to repair workflow YAML that failed parsing/validation.

    Re-renders the copilot prompts with the failing YAML and the error text
    as the "user message", appending the AI's previous response to the chat
    history so the model sees its own last turn.

    Args:
        llm_api_handler: Handler chosen for the original copilot call.
        organization_id: Organization the chat belongs to.
        user_response: The AI's user-facing text from the failed turn.
        workflow_yaml: The YAML that failed to parse/validate.
        chat_history: Prior messages (copied, not mutated).
        global_llm_context: Sticky context carried across turns, if any.
        debug_run_info_text: Pre-rendered debug-run context for the prompt.
        error: The exception raised while processing *workflow_yaml*.

    Returns:
        The LLM's corrected YAML, or the original *workflow_yaml* when the
        response carried no ``workflow_yaml`` field.
    """
    failure_reason = f"{error.__class__.__name__}: {error}"

    # Copy so the caller's history list is not mutated.
    new_chat_history = chat_history[:]
    new_chat_history.append(
        WorkflowCopilotChatHistoryMessage(
            sender=WorkflowCopilotChatSender.AI,
            content=user_response,
            created_at=datetime.now(timezone.utc),
        )
    )

    workflow_knowledge_base = WORKFLOW_KNOWLEDGE_BASE_PATH.read_text(encoding="utf-8")

    # Same prompt pair as copilot_call_llm: trusted system prompt, untrusted
    # user-prompt variables escaped to prevent code-fence breakout.
    security_rules = app.AGENT_FUNCTION.get_copilot_security_rules()
    system_prompt = prompt_engine.load_prompt(
        template="workflow-copilot-system",
        workflow_knowledge_base=workflow_knowledge_base,
        current_datetime=datetime.now(timezone.utc).isoformat(),
        security_rules=security_rules,
    )

    user_prompt = prompt_engine.load_prompt(
        template="workflow-copilot-user",
        workflow_yaml=escape_code_fences(workflow_yaml),
        user_message=escape_code_fences(f"Workflow YAML parsing failed, please fix it: {failure_reason}"),
        chat_history=escape_code_fences(_format_chat_history(new_chat_history)),
        global_llm_context=escape_code_fences(global_llm_context or ""),
        debug_run_info=escape_code_fences(debug_run_info_text),
    )

    llm_start_time = time.monotonic()
    llm_response = await llm_api_handler(
        prompt=user_prompt,
        prompt_name="workflow-copilot",
        organization_id=organization_id,
        system_prompt=system_prompt,
    )
    LOG.info(
        "Auto-correction LLM response",
        duration_seconds=time.monotonic() - llm_start_time,
        llm_response_len=len(llm_response),
        llm_response=llm_response,
    )
    action_data = _parse_llm_response(llm_response)

    # Fall back to the original YAML if the model returned no replacement.
    return action_data.get("workflow_yaml", workflow_yaml)
|
|
|
|
|
|
def _collect_reachable(
    start_label: str,
    label_to_block: dict[str, BlockYAML],
    reachable: set[str],
) -> None:
    """Collect every label reachable from ``start_label`` into *reachable*.

    Walks the main ``next_block_label`` chain iteratively and recurses into
    each conditional branch target. The walk stops as soon as it hits a label
    already in *reachable* — correct, since that node's successors were
    collected by the earlier visit — but callers should be aware of that
    coupling.
    """
    cursor: str | None = start_label
    while True:
        if not cursor or cursor not in label_to_block or cursor in reachable:
            return
        reachable.add(cursor)
        node = label_to_block[cursor]
        if isinstance(node, ConditionalBlockYAML):
            # Branch targets spawn their own chain walks.
            for branch in node.branch_conditions:
                target = branch.next_block_label
                if target and target not in reachable:
                    _collect_reachable(target, label_to_block, reachable)
        cursor = node.next_block_label
|
|
|
|
|
|
def _break_cycles(
    start_label: str,
    label_to_block: dict[str, BlockYAML],
) -> bool:
    """Detect and break circular references in the block chain using DFS.

    Uses a recursion stack to distinguish true back-edges (cycles) from merge
    points (two branches converging on the same block). When a back-edge is
    found the offending ``next_block_label`` is set to ``None``, breaking the
    cycle. Handles both the main chain and conditional branch chains.

    Note: this function operates on a single level of blocks. It does **not**
    recurse into ``ForLoopBlockYAML.loop_blocks``; nested loops are handled
    by the recursive ``_repair_next_block_label_chain`` call in Phase 3.

    Returns True if at least one cycle was broken.
    """
    visited: set[str] = set()
    rec_stack: set[str] = set()
    found_cycle = False

    def _follow_edge(target: str | None, edge_owner: BlockYAML | BranchConditionYAML, parent_label: str) -> None:
        """Follow an edge to *target*. *edge_owner* is the object whose
        ``next_block_label`` will be set to ``None`` when the target forms a
        back-edge. *parent_label* is the block label that owns this edge
        for logging."""
        nonlocal found_cycle
        # Dangling references (unknown labels) are not this pass's concern.
        if not target or target not in label_to_block:
            return
        if target in rec_stack:
            # Back-edge: target is an ancestor on the current DFS path.
            # Branch-condition owners are identified by their ``criteria``
            # attribute, purely for the log record below.
            is_branch = hasattr(edge_owner, "criteria")
            LOG.warning(
                "Copilot produced circular block chain, breaking cycle",
                cycle_target=target,
                broken_at=parent_label,
                is_branch_condition=is_branch,
                branch_expression=getattr(getattr(edge_owner, "criteria", None), "expression", None),
            )
            edge_owner.next_block_label = None
            found_cycle = True
            return
        if target in visited:
            return  # merge point — not a cycle
        _dfs(target)

    def _dfs(label: str) -> None:
        visited.add(label)
        rec_stack.add(label)
        block = label_to_block[label]

        # Branch edges are explored before the main-chain edge.
        if isinstance(block, ConditionalBlockYAML):
            for branch in block.branch_conditions:
                _follow_edge(branch.next_block_label, branch, label)

        _follow_edge(block.next_block_label, block, label)
        rec_stack.discard(label)

    if start_label in label_to_block:
        _dfs(start_label)
    return found_cycle
|
|
|
|
|
|
def _find_terminal_label(
    start_label: str,
    label_to_block: dict[str, BlockYAML],
    all_labels: set[str],
) -> str | None:
    """Walk the main chain from ``start_label`` and return its terminal label.

    A block is terminal when its ``next_block_label`` is None or points
    outside *all_labels*. Returns None if the walk exhausts without finding
    one (e.g. the chain loops back on itself).
    """
    seen: set[str] = set()
    cursor: str | None = start_label
    while cursor and cursor in label_to_block and cursor not in seen:
        seen.add(cursor)
        successor = label_to_block[cursor].next_block_label
        if successor is None or successor not in all_labels:
            return cursor
        cursor = successor
    return None
|
|
|
|
|
|
def _order_orphaned_blocks(
    orphaned_labels: set[str],
    label_to_block: dict[str, BlockYAML],
    all_labels: set[str],
    blocks: list[BlockYAML],
) -> list[str]:
    """Order orphaned blocks by following their internal next_block_label chains.

    Multiple disconnected orphan sub-chains are concatenated in the order their
    chain-start appears in the original blocks list.

    Args:
        orphaned_labels: Labels unreachable from the first block. Must be
            non-empty — the cycle fallback below assumes at least one orphan.
        label_to_block: Label → block mapping for this level.
        all_labels: Unused in this function; kept for signature parity with
            the other chain helpers.
        blocks: Original block list, used for deterministic ordering.

    Returns:
        The orphan labels in stitched order. Side effect: mutates the
        orphans' ``next_block_label`` so they form one connected path ending
        in ``None``.
    """
    # Orphans that some other orphan points at cannot be chain starts.
    pointed_to: set[str] = set()
    for label in orphaned_labels:
        block = label_to_block[label]
        if block.next_block_label and block.next_block_label in orphaned_labels:
            pointed_to.add(block.next_block_label)

    # Chain starts are orphans not pointed to by another orphan.
    # Preserve original array order for deterministic stitching.
    chain_starts = [b.label for b in blocks if b.label in orphaned_labels and b.label not in pointed_to]

    # If all orphans point to each other (cycle), pick the first in array order.
    if not chain_starts:
        chain_starts = [next(b.label for b in blocks if b.label in orphaned_labels)]

    ordered: list[str] = []
    visited: set[str] = set()
    for start in chain_starts:
        current: str | None = start
        # Follow each sub-chain while it stays inside the orphan set.
        while current and current in orphaned_labels and current not in visited:
            visited.add(current)
            ordered.append(current)
            current = label_to_block[current].next_block_label

    # Append any remaining orphans not reached (multiple cycles).
    for block in blocks:
        if block.label in orphaned_labels and block.label not in visited:
            ordered.append(block.label)

    # Re-link the orphan chain so it forms a single connected path.
    # This may overwrite an orphan's original next_block_label that pointed to a
    # reachable block (a merge/join pattern). Log when this happens.
    for i in range(len(ordered) - 1):
        old_target = label_to_block[ordered[i]].next_block_label
        new_target = ordered[i + 1]
        if old_target and old_target != new_target and old_target not in orphaned_labels:
            LOG.info(
                "Orphan re-link overwrites cross-chain reference",
                block=ordered[i],
                old_target=old_target,
                new_target=new_target,
            )
        label_to_block[ordered[i]].next_block_label = new_target
    if ordered:
        old_last_target = label_to_block[ordered[-1]].next_block_label
        if old_last_target and old_last_target not in orphaned_labels:
            LOG.info(
                "Orphan chain terminal overwrites cross-chain reference",
                block=ordered[-1],
                old_target=old_last_target,
            )
        # The stitched chain must terminate; the last orphan points nowhere.
        label_to_block[ordered[-1]].next_block_label = None

    return ordered
|
|
|
|
|
|
def _repair_next_block_label_chain(blocks: list[BlockYAML]) -> None:
    """Make the top-level blocks one acyclic chain starting at ``blocks[0]``.

    Two repair passes over LLM-produced output:

    1. Break circular ``next_block_label`` references so the chain terminates.
    2. Stitch blocks unreachable from ``blocks[0]`` onto the chain's end.

    ``ForLoopBlockYAML.loop_blocks`` are repaired recursively at every depth.
    Mutates *blocks* in place.
    """

    def _repair_nested_loops() -> None:
        # Each for-loop body is its own chain; repair it recursively.
        for candidate in blocks:
            if isinstance(candidate, ForLoopBlockYAML) and candidate.loop_blocks:
                _repair_next_block_label_chain(candidate.loop_blocks)

    if len(blocks) <= 1:
        # Nothing to stitch at this level, but nested loops still need repair.
        _repair_nested_loops()
        return

    # The label→block mapping below keeps only the LAST occurrence of a
    # duplicated label, making earlier same-labelled blocks invisible — warn.
    labels_seen: set[str] = set()
    for blk in blocks:
        if blk.label in labels_seen:
            LOG.warning("Copilot produced duplicate block label", label=blk.label)
        labels_seen.add(blk.label)

    label_to_block: dict[str, BlockYAML] = {blk.label: blk for blk in blocks}
    all_labels = set(label_to_block)

    # Phase 1: break cycles reachable from the first block. Cycles among
    # orphaned blocks are handled implicitly by _order_orphaned_blocks via
    # its visited set and re-linking logic.
    _break_cycles(blocks[0].label, label_to_block)

    # Phase 2: find orphaned (unreachable) blocks and stitch them to the end.
    reachable: set[str] = set()
    _collect_reachable(blocks[0].label, label_to_block, reachable)

    orphaned_labels = all_labels - reachable
    if orphaned_labels:
        LOG.warning(
            "Copilot produced disconnected workflow blocks, repairing chain",
            orphaned_labels=sorted(orphaned_labels),
            reachable_labels=sorted(reachable),
        )

        terminal_label = _find_terminal_label(blocks[0].label, label_to_block, all_labels)
        ordered_orphan_labels = _order_orphaned_blocks(orphaned_labels, label_to_block, all_labels, blocks)

        if terminal_label and ordered_orphan_labels:
            label_to_block[terminal_label].next_block_label = ordered_orphan_labels[0]

    # Phase 3: recursively repair nested ForLoopBlockYAML.loop_blocks.
    _repair_nested_loops()
|
|
|
|
|
|
def _process_workflow_yaml(
    workflow_id: str,
    workflow_permanent_id: str,
    organization_id: str,
    workflow_yaml: str,
) -> Workflow:
    """Parse copilot-produced YAML, repair common LLM mistakes, build a Workflow.

    Args:
        workflow_id: ID assigned to the resulting workflow.
        workflow_permanent_id: Permanent ID carried onto the result.
        organization_id: Owning organization.
        workflow_yaml: Raw YAML text proposed by the LLM.

    Returns:
        An in-memory ``Workflow`` (version pinned to 1; not persisted here).

    Raises:
        yaml.YAMLError: if the YAML cannot be parsed or is not a mapping.
        ValidationError: if the parsed document fails schema validation.
    """
    parsed_yaml = safe_load_no_dates(workflow_yaml)

    # Fix: a scalar/list document would make ``parsed_yaml.get`` raise
    # AttributeError, which escapes the caller's auto-correction handler
    # (it only catches yaml.YAMLError / ValidationError /
    # BaseWorkflowHTTPException). Raise yaml.YAMLError instead so the LLM
    # repair loop in copilot_call_llm gets a chance to fix the document.
    if not isinstance(parsed_yaml, dict):
        raise yaml.YAMLError(f"Workflow YAML must be a mapping, got {type(parsed_yaml).__name__}")

    # Fixing trivial common LLM mistakes: ensure every block carries a title.
    workflow_definition = parsed_yaml.get("workflow_definition", None)
    if isinstance(workflow_definition, dict):
        blocks = workflow_definition.get("blocks", [])
        if isinstance(blocks, list):
            for block in blocks:
                # Non-dict entries are left for model_validate to reject
                # with a ValidationError the caller can auto-correct.
                if isinstance(block, dict):
                    block["title"] = block.get("title", "")

    workflow_yaml_request = WorkflowCreateYAMLRequest.model_validate(parsed_yaml)

    # Post-processing: default empty login navigation goals.
    for block in workflow_yaml_request.workflow_definition.blocks:
        if isinstance(block, LoginBlockYAML) and not block.navigation_goal:
            block.navigation_goal = DEFAULT_LOGIN_PROMPT

    # Strip OUTPUT-type parameters the LLM may have included — presumably
    # these are derived elsewhere at runtime (TODO confirm with the
    # workflow-service owners).
    workflow_yaml_request.workflow_definition.parameters = [
        p for p in workflow_yaml_request.workflow_definition.parameters if p.parameter_type != ParameterType.OUTPUT
    ]

    # Repair cyclic / disconnected next_block_label chains before conversion.
    _repair_next_block_label_chain(workflow_yaml_request.workflow_definition.blocks)

    updated_workflow_definition = convert_workflow_definition(
        workflow_definition_yaml=workflow_yaml_request.workflow_definition,
        workflow_id=workflow_id,
    )

    now = datetime.now(timezone.utc)
    return Workflow(
        workflow_id=workflow_id,
        organization_id=organization_id,
        title=workflow_yaml_request.title or "",
        workflow_permanent_id=workflow_permanent_id,
        version=1,
        is_saved_task=workflow_yaml_request.is_saved_task,
        description=workflow_yaml_request.description,
        workflow_definition=updated_workflow_definition,
        proxy_location=workflow_yaml_request.proxy_location,
        webhook_callback_url=workflow_yaml_request.webhook_callback_url,
        persist_browser_session=workflow_yaml_request.persist_browser_session or False,
        model=workflow_yaml_request.model,
        max_screenshot_scrolls=workflow_yaml_request.max_screenshot_scrolls,
        extra_http_headers=workflow_yaml_request.extra_http_headers,
        run_with=workflow_yaml_request.run_with,
        ai_fallback=workflow_yaml_request.ai_fallback,
        cache_key=workflow_yaml_request.cache_key,
        run_sequentially=workflow_yaml_request.run_sequentially,
        sequential_key=workflow_yaml_request.sequential_key,
        created_at=now,
        modified_at=now,
    )
|
|
|
|
|
|
async def _new_copilot_chat_post(
    request: Request,
    chat_request: WorkflowCopilotChatRequest,
    organization: Organization,
) -> EventSourceResponse:
    """ENABLE_WORKFLOW_COPILOT_V2 dispatch target.

    Runs the openai-agents-SDK copilot (skyvern.forge.sdk.copilot.agent) and
    streams responses in the same SSE shape the frontend consumes. On
    mid-stream failure (HTTPException, LLMProviderError, asyncio.CancelledError,
    or unexpected exception), rolls the workflow definition back to
    ``original_workflow`` via ``_restore_workflow_definition`` to avoid leaving
    a half-persisted draft.
    """

    async def stream_handler(stream: EventSourceStream) -> None:
        # One chat turn over SSE: resolve/create the chat, assemble run
        # context for the agent, run it, persist proposal + chat messages,
        # and emit the final RESPONSE frame.
        LOG.info(
            "Workflow copilot v2 chat request",
            workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
            workflow_run_id=chat_request.workflow_run_id,
            message=chat_request.message,
            workflow_yaml_length=len(chat_request.workflow_yaml),
            organization_id=organization.organization_id,
        )

        # Pre-declared so the except handlers below can safely consult them
        # even when the failure happened before they were assigned.
        original_workflow: Workflow | None = None
        chat = None
        agent_result: Any = None

        try:
            await stream.send(
                WorkflowCopilotProcessingUpdate(
                    type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                    status="Processing...",
                    timestamp=datetime.now(timezone.utc),
                )
            )

            # Resume an existing chat (validating it belongs to this org and
            # workflow) or create a fresh one for this workflow.
            if chat_request.workflow_copilot_chat_id:
                chat = await app.DATABASE.workflow_params.get_workflow_copilot_chat_by_id(
                    organization_id=organization.organization_id,
                    workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
                )
                if not chat:
                    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
                if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
            else:
                chat = await app.DATABASE.workflow_params.create_workflow_copilot_chat(
                    organization_id=organization.organization_id,
                    workflow_permanent_id=chat_request.workflow_permanent_id,
                )

            chat_request.workflow_copilot_chat_id = chat.workflow_copilot_chat_id

            chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
            )
            # Carry forward the most recent non-null global LLM context from
            # prior turns of this chat (newest message wins).
            global_llm_context = None
            for message in reversed(chat_messages):
                if message.global_llm_context is not None:
                    global_llm_context = message.global_llm_context
                    break

            # A pending (not yet accepted) proposal stores its YAML under the
            # "_copilot_yaml" key; have the agent iterate on that draft
            # instead of the saved workflow definition.
            if chat.proposed_workflow and chat.proposed_workflow.get("_copilot_yaml"):
                chat_request.workflow_yaml = chat.proposed_workflow["_copilot_yaml"]

            block_infos, debug_html = await _get_new_copilot_block_infos(
                organization.organization_id, chat_request.workflow_run_id
            )

            # Flatten per-block run results (status/failure/output) plus the
            # visible-elements HTML into a plain-text blob for the agent.
            debug_run_info_text = ""
            if block_infos:
                parts: list[str] = []
                for bi in block_infos:
                    block_text = f"Block: {bi.block_label} ({bi.block_type}) — {bi.block_status}"
                    if bi.failure_reason:
                        block_text += f"\n  Failure Reason: {bi.failure_reason}"
                    if bi.output:
                        block_text += f"\n  Output: {bi.output}"
                    parts.append(block_text)
                debug_run_info_text = "\n".join(parts)
            # NOTE(review): appended even when block_infos is empty —
            # original nesting level is ambiguous in the source; confirm.
            if debug_html:
                debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_html}"

            await stream.send(
                WorkflowCopilotProcessingUpdate(
                    type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                    status="Thinking...",
                    timestamp=datetime.now(timezone.utc),
                )
            )

            # No early exit on disconnect (SKY-8986): the agent runs to
            # completion even after the SSE stream drops so its reply is
            # persisted to the chat history and visible after reconnect.

            original_workflow = await app.DATABASE.workflows.get_workflow_by_permanent_id(
                workflow_permanent_id=chat_request.workflow_permanent_id,
                organization_id=organization.organization_id,
            )

            if not original_workflow:
                raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Workflow not found")

            chat_request.workflow_id = original_workflow.workflow_id

            # Per-prompt-type experiment override; falls back to the default
            # org-level handler when no override is configured.
            llm_api_handler = (
                await get_llm_handler_for_prompt_type(
                    "workflow-copilot", chat_request.workflow_permanent_id, organization.organization_id
                )
                or app.LLM_API_HANDLER
            )

            api_key = request.headers.get("x-api-key")
            security_rules = app.AGENT_FUNCTION.get_copilot_security_rules()

            agent_result = await run_copilot_agent(
                stream=stream,
                organization_id=organization.organization_id,
                chat_request=chat_request,
                chat_history=convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
                global_llm_context=global_llm_context,
                debug_run_info_text=debug_run_info_text,
                llm_api_handler=llm_api_handler,
                api_key=api_key,
                security_rules=security_rules,
            )

            user_response = agent_result.user_response
            updated_workflow = agent_result.updated_workflow
            updated_global_llm_context = agent_result.global_llm_context

            # Persist rollback / proposed-workflow state and the chat
            # messages regardless of whether the SSE client is still
            # connected: the user needs to see the reply on reconnect.
            # SKY-8986: client disconnect used to short-circuit this block
            # and leave the chat history without the AI response.
            #
            # SKY-9143: restore runs outside the auto_accept wrapper so
            # auto-accept turns that ended without a viable proposal still
            # roll back a mid-turn _update_workflow write. The Accept/Reject
            # panel state below stays gated on auto_accept — the frontend
            # applies proposals via applyWorkflowUpdate when auto-accept is
            # on.
            restored = _should_restore_persisted_workflow(chat.auto_accept, agent_result)
            if restored:
                await _restore_workflow_definition(original_workflow, organization.organization_id)

            if chat.auto_accept is not True:
                if updated_workflow:
                    # Persist the proposal for the Accept/Reject panel; keep
                    # the raw YAML alongside it so follow-up turns can
                    # iterate on the draft (see "_copilot_yaml" read above).
                    proposed_data = updated_workflow.model_dump(mode="json")
                    if agent_result.workflow_yaml:
                        proposed_data["_copilot_yaml"] = agent_result.workflow_yaml
                    await app.DATABASE.workflow_params.update_workflow_copilot_chat(
                        organization_id=chat.organization_id,
                        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                        proposed_workflow=proposed_data,
                    )
                elif (
                    restored or getattr(agent_result, "clear_proposed_workflow", False)
                ) and chat.proposed_workflow is not None:
                    # Null any previously-persisted proposed_workflow so a
                    # page reload does not resurrect a stale Accept/Reject
                    # card next to an assistant message that just explained
                    # why no verified proposal is available. Covers:
                    # * feasibility-gate fast-path clarifications, and
                    # * SKY-9143 strict-gate turns where a mid-turn draft was
                    #   rolled back (``restored=True``).
                    await app.DATABASE.workflow_params.update_workflow_copilot_chat(
                        organization_id=chat.organization_id,
                        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                        proposed_workflow=None,
                    )

            await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
                organization_id=chat.organization_id,
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                sender=WorkflowCopilotChatSender.USER,
                content=chat_request.message,
            )

            assistant_message = await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
                organization_id=chat.organization_id,
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                sender=WorkflowCopilotChatSender.AI,
                content=user_response,
                global_llm_context=updated_global_llm_context,
            )

            await stream.send(
                WorkflowCopilotStreamResponseUpdate(
                    type=WorkflowCopilotStreamMessageType.RESPONSE,
                    workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                    message=user_response,
                    updated_workflow=updated_workflow.model_dump(mode="json") if updated_workflow else None,
                    response_time=assistant_message.created_at,
                    total_tokens=getattr(agent_result, "total_tokens", None),
                    response_type=getattr(agent_result, "response_type", "REPLY"),
                )
            )
        except HTTPException as exc:
            # Roll back any mid-turn workflow write before surfacing the error.
            if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
                await _restore_workflow_definition(original_workflow, organization.organization_id)
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error=exc.detail,
                )
            )
        except LLMProviderError as exc:
            if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
                await _restore_workflow_definition(original_workflow, organization.organization_id)
            LOG.error(
                "LLM provider error (copilot v2)",
                organization_id=organization.organization_id,
                error=str(exc),
                exc_info=True,
            )
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error="Failed to process your request. Please try again.",
                )
            )
        except asyncio.CancelledError:
            # Shield the restore so the rollback completes even though this
            # task is being cancelled (client disconnect).
            if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
                await asyncio.shield(_restore_workflow_definition(original_workflow, organization.organization_id))
            LOG.info(
                "Client disconnected during workflow copilot v2",
                workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
            )
        except Exception as exc:
            if chat is not None and _should_restore_persisted_workflow(chat.auto_accept, agent_result):
                await _restore_workflow_definition(original_workflow, organization.organization_id)
            LOG.error(
                "Unexpected error in workflow copilot v2",
                organization_id=organization.organization_id,
                error=str(exc),
                exc_info=True,
            )
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error="An error occurred. Please try again.",
                )
            )

    return FastAPIEventSourceStream.create(request, stream_handler)
|
|
|
|
|
|
# Experimentation-provider flag key gating the org-sticky rollout of the
# v2 (openai-agents-SDK) copilot; see _should_use_copilot_v2.
COPILOT_V2_FLAG_KEY = "ENABLE_WORKFLOW_COPILOT_V2"
|
|
|
|
|
|
async def _should_use_copilot_v2(organization: Organization, workflow_permanent_id: str) -> bool:
    """Decide whether this organization is routed to the v2 copilot.

    A true ``settings.ENABLE_WORKFLOW_COPILOT_V2`` short-circuits to ``True``;
    otherwise the experimentation provider is consulted. Any failure while
    evaluating the flag is logged and falls back to the legacy copilot.
    """
    if settings.ENABLE_WORKFLOW_COPILOT_V2:
        return True

    org_id = organization.organization_id
    try:
        # distinct_id is the org (not a run id) because this gate is an org-sticky rollout:
        # copilot chat may not have a stable run at dispatch time, and we want each org to
        # see the same path across sessions. Contrast with backend.md's default of run-level
        # ids for per-run randomized experiments.
        enabled = await app.EXPERIMENTATION_PROVIDER.is_feature_enabled_cached(
            COPILOT_V2_FLAG_KEY,
            distinct_id=org_id,
            properties={"organization_id": org_id},
        )
    except Exception:
        LOG.exception(
            "Failed to evaluate copilot-v2 feature flag; falling back to legacy copilot",
            organization_id=org_id,
            workflow_permanent_id=workflow_permanent_id,
        )
        return False
    return enabled
|
|
|
|
|
|
@base_router.post("/workflow/copilot/chat-post", include_in_schema=False)
async def workflow_copilot_chat_post(
    request: Request,
    chat_request: WorkflowCopilotChatRequest,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> EventSourceResponse:
    """Legacy workflow-copilot SSE chat endpoint.

    Dispatches to the v2 agent (``_new_copilot_chat_post``) when the
    org-sticky rollout flag is on; otherwise runs the original
    ``copilot_call_llm`` pipeline and streams the same SSE message shapes.
    """
    if await _should_use_copilot_v2(organization, chat_request.workflow_permanent_id):
        return await _new_copilot_chat_post(request, chat_request, organization)

    async def stream_handler(stream: EventSourceStream) -> None:
        # One chat turn: resolve/create the chat, build debug context, call
        # the LLM, persist the proposal + messages, emit the final RESPONSE.
        LOG.info(
            "Workflow copilot chat request",
            workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
            workflow_run_id=chat_request.workflow_run_id,
            message=chat_request.message,
            workflow_yaml_length=len(chat_request.workflow_yaml),
            organization_id=organization.organization_id,
        )

        try:
            await stream.send(
                WorkflowCopilotProcessingUpdate(
                    type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                    status="Processing...",
                    timestamp=datetime.now(timezone.utc),
                )
            )

            # Resume an existing chat (validating org + workflow ownership)
            # or create a fresh one for this workflow.
            if chat_request.workflow_copilot_chat_id:
                chat = await app.DATABASE.workflow_params.get_workflow_copilot_chat_by_id(
                    organization_id=organization.organization_id,
                    workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id,
                )
                if not chat:
                    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
                if chat_request.workflow_permanent_id != chat.workflow_permanent_id:
                    raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Wrong workflow permanent ID")
            else:
                chat = await app.DATABASE.workflow_params.create_workflow_copilot_chat(
                    organization_id=organization.organization_id,
                    workflow_permanent_id=chat_request.workflow_permanent_id,
                )

            chat_messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
            )
            # Carry forward the most recent non-null global LLM context from
            # earlier turns of this chat.
            global_llm_context = None
            for message in reversed(chat_messages):
                if message.global_llm_context is not None:
                    global_llm_context = message.global_llm_context
                    break

            debug_run_info = await _get_debug_run_info(organization.organization_id, chat_request.workflow_run_id)

            # Format debug run info for prompt
            debug_run_info_text = ""
            if debug_run_info:
                debug_run_info_text = f"Block Label: {debug_run_info.block_label}"
                debug_run_info_text += f" Block Type: {debug_run_info.block_type}"
                debug_run_info_text += f" Status: {debug_run_info.block_status}"
                if debug_run_info.failure_reason:
                    debug_run_info_text += f"\nFailure Reason: {debug_run_info.failure_reason}"
                if debug_run_info.html:
                    debug_run_info_text += f"\n\nVisible Elements Tree (HTML):\n{debug_run_info.html}"

            await stream.send(
                WorkflowCopilotProcessingUpdate(
                    type=WorkflowCopilotStreamMessageType.PROCESSING_UPDATE,
                    status="Thinking...",
                    timestamp=datetime.now(timezone.utc),
                )
            )

            # SKY-8986: do not short-circuit on client disconnect. The LLM
            # call and the DB persistence below must complete so the reply
            # is in the chat history when the user reconnects.
            user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm(
                stream,
                organization.organization_id,
                chat_request,
                convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]),
                global_llm_context,
                debug_run_info_text,
            )

            # When auto-accept is off, persist the proposal so the frontend
            # can show the Accept/Reject panel after a reload.
            if updated_workflow and chat.auto_accept is not True:
                await app.DATABASE.workflow_params.update_workflow_copilot_chat(
                    organization_id=chat.organization_id,
                    workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                    proposed_workflow=updated_workflow.model_dump(mode="json"),
                )

            await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
                organization_id=chat.organization_id,
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                sender=WorkflowCopilotChatSender.USER,
                content=chat_request.message,
            )

            assistant_message = await app.DATABASE.workflow_params.create_workflow_copilot_chat_message(
                organization_id=chat.organization_id,
                workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                sender=WorkflowCopilotChatSender.AI,
                content=user_response,
                global_llm_context=updated_global_llm_context,
            )

            await stream.send(
                WorkflowCopilotStreamResponseUpdate(
                    type=WorkflowCopilotStreamMessageType.RESPONSE,
                    workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
                    message=user_response,
                    updated_workflow=updated_workflow.model_dump(mode="json") if updated_workflow else None,
                    response_time=assistant_message.created_at,
                )
            )
        except HTTPException as exc:
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error=exc.detail,
                )
            )
        except LLMProviderError as exc:
            LOG.error(
                "LLM provider error",
                organization_id=organization.organization_id,
                error=str(exc),
                exc_info=True,
            )
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error="Failed to process your request. Please try again.",
                )
            )
        except Exception as exc:
            LOG.error(
                "Unexpected error in workflow copilot",
                organization_id=organization.organization_id,
                error=str(exc),
                exc_info=True,
            )
            await stream.send(
                WorkflowCopilotStreamErrorUpdate(
                    type=WorkflowCopilotStreamMessageType.ERROR,
                    error="An error occurred. Please try again.",
                )
            )

    return FastAPIEventSourceStream.create(request, stream_handler)
|
|
|
|
|
|
@base_router.get("/workflow/copilot/chat-history", include_in_schema=False)
async def workflow_copilot_chat_history(
    workflow_permanent_id: str,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowCopilotChatHistoryResponse:
    """Return the most recent copilot chat for a workflow.

    Includes the chat id, its message history, any pending proposed workflow,
    and the auto-accept flag; all fields are empty/None when no chat exists.
    """
    chat = await app.DATABASE.workflow_params.get_latest_workflow_copilot_chat(
        organization_id=organization.organization_id,
        workflow_permanent_id=workflow_permanent_id,
    )
    if not chat:
        # No chat yet for this workflow: empty history, nothing pending.
        return WorkflowCopilotChatHistoryResponse(
            workflow_copilot_chat_id=None,
            chat_history=[],
            proposed_workflow=None,
            auto_accept=None,
        )

    messages = await app.DATABASE.workflow_params.get_workflow_copilot_chat_messages(
        chat.workflow_copilot_chat_id
    )
    return WorkflowCopilotChatHistoryResponse(
        workflow_copilot_chat_id=chat.workflow_copilot_chat_id,
        chat_history=convert_to_history_messages(messages),
        proposed_workflow=chat.proposed_workflow,
        auto_accept=chat.auto_accept,
    )
|
|
|
|
|
|
@base_router.post(
    "/workflow/copilot/clear-proposed-workflow", include_in_schema=False, status_code=status.HTTP_204_NO_CONTENT
)
async def workflow_copilot_clear_proposed_workflow(
    clear_request: WorkflowCopilotClearProposedWorkflowRequest,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> None:
    """Drop the chat's pending proposed workflow and persist the auto-accept toggle.

    Raises:
        HTTPException: 404 when no chat matches the given id for this org.
    """
    chat = await app.DATABASE.workflow_params.update_workflow_copilot_chat(
        organization_id=organization.organization_id,
        workflow_copilot_chat_id=clear_request.workflow_copilot_chat_id,
        proposed_workflow=None,
        auto_accept=clear_request.auto_accept,
    )
    if not chat:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found")
|
|
|
|
|
|
def convert_to_history_messages(
    messages: list[WorkflowCopilotChatMessage],
) -> list[WorkflowCopilotChatHistoryMessage]:
    """Project stored chat rows onto the lightweight history schema sent to the frontend.

    Order is preserved; only sender, content, and created_at are carried over.
    """
    history: list[WorkflowCopilotChatHistoryMessage] = []
    for msg in messages:
        history.append(
            WorkflowCopilotChatHistoryMessage(
                sender=msg.sender,
                content=msg.content,
                created_at=msg.created_at,
            )
        )
    return history
|
|
|
|
|
|
@base_router.post("/workflow/copilot/convert-yaml-to-blocks", include_in_schema=False)
async def workflow_copilot_convert_yaml_to_blocks(
    request: WorkflowYAMLConversionRequest,
    organization: Organization = Depends(org_auth_service.get_current_org),
) -> WorkflowYAMLConversionResponse:
    """
    Convert workflow definition YAML to blocks format for comparison view.
    This endpoint is used by the frontend to convert YAML to the proper blocks structure
    that the comparison panel expects.

    Raises:
        HTTPException: 400 when the YAML fails to parse, validate, or convert.
    """
    try:
        parsed_yaml = safe_load_no_dates(request.workflow_definition_yaml)
        workflow_definition_yaml = WorkflowDefinitionYAML.model_validate(parsed_yaml)

        # Normalize dangling next_block_label references before conversion so
        # the produced block chain is self-consistent.
        _repair_next_block_label_chain(workflow_definition_yaml.blocks)

        workflow_definition = convert_workflow_definition(
            workflow_definition_yaml=workflow_definition_yaml,
            workflow_id=request.workflow_id,
        )

        return WorkflowYAMLConversionResponse(workflow_definition=workflow_definition.model_dump(mode="json"))
    except (yaml.YAMLError, ValidationError, BaseWorkflowHTTPException) as e:
        # Chain the cause (PEP 3134) so logs keep the underlying parse /
        # validation failure instead of an opaque "during handling of..." trace.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Failed to convert workflow YAML: {str(e)}",
        ) from e
|