diff --git a/alembic/versions/2026_04_26_0338-bd362c15b74b_add_copilot_attribution_columns.py b/alembic/versions/2026_04_26_0338-bd362c15b74b_add_copilot_attribution_columns.py new file mode 100644 index 000000000..a5492cdb4 --- /dev/null +++ b/alembic/versions/2026_04_26_0338-bd362c15b74b_add_copilot_attribution_columns.py @@ -0,0 +1,31 @@ +"""add_copilot_attribution_columns + +Revision ID: bd362c15b74b +Revises: 70b5f11e3655 +Create Date: 2026-04-26T03:38:25.486705+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "bd362c15b74b" +down_revision: Union[str, None] = "70b5f11e3655" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("workflows", sa.Column("created_by", sa.String(), nullable=True)) + op.add_column("workflows", sa.Column("edited_by", sa.String(), nullable=True)) + op.add_column("workflow_runs", sa.Column("copilot_session_id", sa.String(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("workflow_runs", "copilot_session_id") + op.drop_column("workflows", "edited_by") + op.drop_column("workflows", "created_by") diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index e8f3445ff..a7c07425d 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -104,6 +104,9 @@ def _enrich_llm_span( span.set_attribute("gen_ai.usage.reasoning_tokens", reasoning_tokens) span.set_attribute("gen_ai.usage.cached_tokens", cached_tokens) span.set_attribute("gen_ai.usage.cost", llm_cost) + ctx = skyvern_context.current() + if ctx is not None and ctx.copilot_session_id is not None: + span.set_attribute("copilot.session_id", ctx.copilot_session_id) span.add_event( LLM_REQUEST_COMPLETED_EVENT, attributes={ diff --git a/skyvern/forge/sdk/copilot/agent.py b/skyvern/forge/sdk/copilot/agent.py index 2cc00c8a2..303e7e670 100644 --- a/skyvern/forge/sdk/copilot/agent.py +++ b/skyvern/forge/sdk/copilot/agent.py @@ -389,6 +389,7 @@ async def run_copilot_agent( stream=stream, api_key=api_key, user_message=chat_request.message, + workflow_copilot_chat_id=chat_request.workflow_copilot_chat_id, ) model_name, run_config, llm_key, supports_vision = resolve_model_config(llm_api_handler) diff --git a/skyvern/forge/sdk/copilot/attribution.py b/skyvern/forge/sdk/copilot/attribution.py new file mode 100644 index 000000000..6442be92f --- /dev/null +++ b/skyvern/forge/sdk/copilot/attribution.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from typing import Any + +import structlog + +from skyvern.forge import app +from skyvern.forge.sdk.db._sentinels import _UNSET +from skyvern.forge.sdk.workflow.models.workflow import Workflow + +LOG = structlog.get_logger() + + +def is_copilot_born_initial_write(workflow: Workflow | None) -> bool: + if workflow is None: + return False + if workflow.created_by is not None: + return False + if workflow.version != 1: + return False + return len(workflow.workflow_definition.blocks) == 0 + + +async def resolve_copilot_created_by_stamp(workflow_id: str, organization_id: str) -> Any: + """Return ``"copilot"`` for a copilot-born initial write, ``_UNSET`` otherwise. + + ``_UNSET`` (not ``None``) so the repo's omit-vs-clear sentinel preserves prior values. + """ + try: + workflow = await app.WORKFLOW_SERVICE.get_workflow( + workflow_id=workflow_id, + organization_id=organization_id, + ) + except Exception: + LOG.warning( + "Failed pre-update workflow read for copilot attribution; skipping created_by stamp", + workflow_id=workflow_id, + exc_info=True, + ) + return _UNSET + try: + if is_copilot_born_initial_write(workflow): + return "copilot" + except Exception: + LOG.warning( + "is_copilot_born_initial_write raised; skipping created_by stamp", + workflow_id=workflow_id, + exc_info=True, + ) + return _UNSET diff --git a/skyvern/forge/sdk/copilot/context.py b/skyvern/forge/sdk/copilot/context.py index 91aeae717..ae8304ff8 100644 --- a/skyvern/forge/sdk/copilot/context.py +++ b/skyvern/forge/sdk/copilot/context.py @@ -142,6 +142,8 @@ class CopilotContext(AgentContext): avoid drift. """ + workflow_copilot_chat_id: str | None = None + # Enforcement state navigate_called: bool = False observation_after_navigate: bool = False diff --git a/skyvern/forge/sdk/copilot/tools.py b/skyvern/forge/sdk/copilot/tools.py index a6eb61b23..d64757e61 100644 --- a/skyvern/forge/sdk/copilot/tools.py +++ b/skyvern/forge/sdk/copilot/tools.py @@ -20,6 +20,7 @@ from pydantic import ValidationError from skyvern.forge import app from skyvern.forge.failure_classifier import classify_from_failure_reason from skyvern.forge.sdk.artifact.models import ArtifactType +from skyvern.forge.sdk.copilot.attribution import resolve_copilot_created_by_stamp from skyvern.forge.sdk.copilot.block_goal_wrapping import wrap_block_goals from skyvern.forge.sdk.copilot.context import CopilotContext from skyvern.forge.sdk.copilot.failure_tracking import ( @@ -524,6 +525,9 @@ async def _update_workflow(params: dict[str, Any], ctx: AgentContext) -> dict[st organization_id=ctx.organization_id, workflow_yaml=workflow_yaml, ) + + created_by_stamp = await resolve_copilot_created_by_stamp(ctx.workflow_id, ctx.organization_id) + await app.WORKFLOW_SERVICE.update_workflow_definition( workflow_id=ctx.workflow_id, organization_id=ctx.organization_id, @@ -541,6 +545,8 @@ async def _update_workflow(params: dict[str, Any], ctx: AgentContext) -> dict[st cache_key=workflow.cache_key, run_sequentially=workflow.run_sequentially, sequential_key=workflow.sequential_key, + created_by=created_by_stamp, + edited_by="copilot", ) ctx.workflow_yaml = workflow_yaml return { @@ -1120,6 +1126,7 @@ async def _run_blocks_and_collect_debug( version=None, max_steps=None, request_id=None, + copilot_session_id=ctx.workflow_copilot_chat_id, ) from skyvern.utils.files import initialize_skyvern_state_file diff --git a/skyvern/forge/sdk/copilot/tracing_setup.py b/skyvern/forge/sdk/copilot/tracing_setup.py index 1bda417f8..be8ad6755 100644 --- a/skyvern/forge/sdk/copilot/tracing_setup.py +++ b/skyvern/forge/sdk/copilot/tracing_setup.py @@ -18,6 +18,7 @@ import structlog # Reuse the HTTP-logging redactor so trace-side and SSE-side redaction share # one exact-match sensitive-key policy. from skyvern.forge.request_logging import redact_sensitive_fields +from skyvern.forge.sdk.core import skyvern_context LOG = structlog.get_logger() @@ -225,6 +226,10 @@ def _patch_agent_span_attributes() -> None: # trace backend. attrs["input"] = "[redacted: serialization error]" LOG.warning("Copilot tool-call input redaction failed", error=repr(exc)) + ctx = skyvern_context.current() + if ctx is not None and ctx.copilot_session_id is not None: + if isinstance(span_data, (AgentSpanData, GenerationSpanData, FunctionSpanData)): + attrs["copilot.session_id"] = ctx.copilot_session_id return attrs _oai_mod.attributes_from_span_data = _patched diff --git a/skyvern/forge/sdk/core/skyvern_context.py b/skyvern/forge/sdk/core/skyvern_context.py index feb6bad37..527abe084 100644 --- a/skyvern/forge/sdk/core/skyvern_context.py +++ b/skyvern/forge/sdk/core/skyvern_context.py @@ -28,6 +28,7 @@ class SkyvernContext: browser_session_id: str | None = None tz_info: ZoneInfo | None = None run_id: str | None = None + copilot_session_id: str | None = None totp_codes: dict[str, str | None] = field(default_factory=dict) log: list[dict] = field(default_factory=list) hashed_href_map: dict[str, str] = field(default_factory=dict) @@ -101,7 +102,7 @@ class SkyvernContext: proactive_captcha_task_ids: set[str] = field(default_factory=set) def __repr__(self) -> str: - return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, step_id={self.step_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override}, run_id={self.run_id})" + return f"SkyvernContext(request_id={self.request_id}, organization_id={self.organization_id}, task_id={self.task_id}, step_id={self.step_id}, workflow_id={self.workflow_id}, workflow_run_id={self.workflow_run_id}, task_v2_id={self.task_v2_id}, max_steps_override={self.max_steps_override}, run_id={self.run_id}, copilot_session_id={self.copilot_session_id})" def __str__(self) -> str: return self.__repr__() diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index 87df33a87..b4b05b2e2 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -323,6 +323,8 @@ class WorkflowModel(SoftDeleteMixin, Base): sequential_key = Column(String, nullable=True) folder_id = Column(String, ForeignKey("folders.folder_id", ondelete="SET NULL"), nullable=True) import_error = Column(String, nullable=True) # Error message if import failed + created_by = Column(String, nullable=True) + edited_by = Column(String, nullable=True) created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column( @@ -436,6 +438,7 @@ class WorkflowRunModel(Base): ignore_inherited_workflow_system_prompt = Column( Boolean, nullable=False, default=False, server_default=sqlalchemy.false() ) + copilot_session_id = Column(String, nullable=True) queued_at = Column(DateTime, nullable=True) started_at = Column(DateTime, nullable=True) diff --git a/skyvern/forge/sdk/db/repositories/workflow_runs.py b/skyvern/forge/sdk/db/repositories/workflow_runs.py index 1864b27c8..05f4a5b5c 100644 --- a/skyvern/forge/sdk/db/repositories/workflow_runs.py +++ b/skyvern/forge/sdk/db/repositories/workflow_runs.py @@ -164,6 +164,7 @@ class WorkflowRunsRepository(BaseRepository): trigger_type: WorkflowRunTriggerType | None = None, workflow_schedule_id: str | None = None, ignore_inherited_workflow_system_prompt: bool = False, + copilot_session_id: str | None = None, ) -> WorkflowRun: async with self.Session() as session: kwargs: dict[str, Any] = {} @@ -192,6 +193,7 @@ class WorkflowRunsRepository(BaseRepository): trigger_type=trigger_type.value if trigger_type else None, workflow_schedule_id=workflow_schedule_id, ignore_inherited_workflow_system_prompt=ignore_inherited_workflow_system_prompt, + copilot_session_id=copilot_session_id, **kwargs, ) session.add(workflow_run) diff --git a/skyvern/forge/sdk/db/repositories/workflows.py b/skyvern/forge/sdk/db/repositories/workflows.py index fc3db0280..3ff63b70b 100644 --- a/skyvern/forge/sdk/db/repositories/workflows.py +++ b/skyvern/forge/sdk/db/repositories/workflows.py @@ -99,6 +99,8 @@ class WorkflowsRepository(BaseRepository): run_sequentially: bool = False, sequential_key: str | None = None, folder_id: str | None = None, + created_by: str | None = None, + edited_by: str | None = None, ) -> Workflow: async with self.Session() as session: workflow = WorkflowModel( @@ -125,6 +127,8 @@ class WorkflowsRepository(BaseRepository): run_sequentially=run_sequentially, sequential_key=sequential_key, folder_id=folder_id, + created_by=created_by, + edited_by=edited_by, ) if workflow_permanent_id: workflow.workflow_permanent_id = workflow_permanent_id @@ -593,6 +597,8 @@ class WorkflowsRepository(BaseRepository): ai_fallback: bool | None = None, run_sequentially: bool | None = None, sequential_key: str | None | object = _UNSET, + created_by: str | None | object = _UNSET, + edited_by: str | None | object = _UNSET, ) -> Workflow: async with self.Session() as session: get_workflow_query = exclude_deleted( @@ -635,6 +641,10 @@ class WorkflowsRepository(BaseRepository): workflow.run_sequentially = run_sequentially if sequential_key is not _UNSET: workflow.sequential_key = sequential_key + if created_by is not _UNSET: + workflow.created_by = cast(str | None, created_by) + if edited_by is not _UNSET: + workflow.edited_by = cast(str | None, edited_by) await session.commit() await session.refresh(workflow) is_template = ( @@ -675,6 +685,8 @@ class WorkflowsRepository(BaseRepository): ai_fallback: bool | None = None, run_sequentially: bool | None = None, sequential_key: str | None | object = _UNSET, + created_by: str | None | object = _UNSET, + edited_by: str | None | object = _UNSET, ) -> Workflow: """One-session, one-commit update of the workflow row + definition-parameter rows. @@ -753,6 +765,10 @@ class WorkflowsRepository(BaseRepository): workflow.run_sequentially = run_sequentially if sequential_key is not _UNSET: workflow.sequential_key = sequential_key + if created_by is not _UNSET: + workflow.created_by = cast(str | None, created_by) + if edited_by is not _UNSET: + workflow.edited_by = cast(str | None, edited_by) await session.commit() await session.refresh(workflow) diff --git a/skyvern/forge/sdk/db/utils.py b/skyvern/forge/sdk/db/utils.py index ae164c876..8cf2dfb2c 100644 --- a/skyvern/forge/sdk/db/utils.py +++ b/skyvern/forge/sdk/db/utils.py @@ -441,6 +441,8 @@ def convert_to_workflow( sequential_key=workflow_model.sequential_key, folder_id=workflow_model.folder_id, import_error=workflow_model.import_error, + created_by=workflow_model.created_by, + edited_by=workflow_model.edited_by, ) @@ -491,6 +493,7 @@ def convert_to_workflow_run( workflow_schedule_id=workflow_run_model.workflow_schedule_id, failure_category=workflow_run_model.failure_category, ignore_inherited_workflow_system_prompt=workflow_run_model.ignore_inherited_workflow_system_prompt, + copilot_session_id=workflow_run_model.copilot_session_id, ) diff --git a/skyvern/forge/sdk/routes/workflow_copilot.py b/skyvern/forge/sdk/routes/workflow_copilot.py index cf51467c4..9465a4ab9 100644 --- a/skyvern/forge/sdk/routes/workflow_copilot.py +++ b/skyvern/forge/sdk/routes/workflow_copilot.py @@ -1,9 +1,10 @@ import asyncio import time +from contextlib import contextmanager from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path -from typing import Any +from typing import Any, Iterator import structlog import yaml @@ -19,12 +20,15 @@ from skyvern.forge.sdk.api.llm.api_handler import LLMAPIHandler from skyvern.forge.sdk.api.llm.exceptions import LLMProviderError from skyvern.forge.sdk.artifact.models import Artifact, ArtifactType from skyvern.forge.sdk.copilot.agent import run_copilot_agent +from skyvern.forge.sdk.copilot.attribution import is_copilot_born_initial_write from skyvern.forge.sdk.copilot.output_utils import truncate_output +from skyvern.forge.sdk.core import skyvern_context from skyvern.forge.sdk.experimentation.llm_prompt_config import get_llm_handler_for_prompt_type from skyvern.forge.sdk.routes.event_source_stream import EventSourceStream, FastAPIEventSourceStream from skyvern.forge.sdk.routes.routers import base_router from skyvern.forge.sdk.schemas.organizations import Organization from skyvern.forge.sdk.schemas.workflow_copilot import ( + WorkflowCopilotApplyProposedWorkflowRequest, WorkflowCopilotChatHistoryMessage, WorkflowCopilotChatHistoryResponse, WorkflowCopilotChatMessage, @@ -61,6 +65,21 @@ CHAT_HISTORY_CONTEXT_MESSAGES = 10 LOG = structlog.get_logger() +@contextmanager +def bind_copilot_session_id(chat_id: str | None) -> Iterator[None]: + # In-place mutation (not scoped()) preserves request-scoped fields the FastAPI middleware wrote. + ctx = skyvern_context.current() + if ctx is None or chat_id is None: + yield + return + prev = ctx.copilot_session_id + ctx.copilot_session_id = chat_id + try: + yield + finally: + ctx.copilot_session_id = prev + + @dataclass(frozen=True) class RunInfo: block_label: str | None @@ -131,12 +150,15 @@ async def _restore_workflow_definition(original_workflow: Workflow | None, organ if not original_workflow: return try: + # Forward attribution so rollback reverts it alongside the definition. await app.WORKFLOW_SERVICE.update_workflow_definition( workflow_id=original_workflow.workflow_id, organization_id=organization_id, title=original_workflow.title, description=original_workflow.description, workflow_definition=original_workflow.workflow_definition, + created_by=original_workflow.created_by, + edited_by=original_workflow.edited_by, ) except Exception: LOG.warning( @@ -675,20 +697,16 @@ def _repair_next_block_label_chain(blocks: list[BlockYAML]) -> None: _repair_next_block_label_chain(block.loop_blocks) -def _process_workflow_yaml( - workflow_id: str, - workflow_permanent_id: str, - organization_id: str, - workflow_yaml: str, -) -> Workflow: +def _normalize_copilot_yaml(workflow_yaml: str) -> WorkflowCreateYAMLRequest: parsed_yaml = safe_load_no_dates(workflow_yaml) - # Fixing trivial common LLM mistakes - workflow_definition = parsed_yaml.get("workflow_definition", None) - if workflow_definition: - blocks = workflow_definition.get("blocks", []) - for block in blocks: - block["title"] = block.get("title", "") + # Fixing trivial common LLM mistakes; non-dict YAML falls through to model_validate. + if isinstance(parsed_yaml, dict): + workflow_definition = parsed_yaml.get("workflow_definition", None) + if workflow_definition: + blocks = workflow_definition.get("blocks", []) or [] + for block in blocks: + block["title"] = block.get("title", "") workflow_yaml_request = WorkflowCreateYAMLRequest.model_validate(parsed_yaml) @@ -703,6 +721,17 @@ def _process_workflow_yaml( _repair_next_block_label_chain(workflow_yaml_request.workflow_definition.blocks) + return workflow_yaml_request + + +def _process_workflow_yaml( + workflow_id: str, + workflow_permanent_id: str, + organization_id: str, + workflow_yaml: str, +) -> Workflow: + workflow_yaml_request = _normalize_copilot_yaml(workflow_yaml) + updated_workflow_definition = convert_workflow_definition( workflow_definition_yaml=workflow_yaml_request.workflow_definition, workflow_id=workflow_id, @@ -852,17 +881,18 @@ async def _new_copilot_chat_post( api_key = request.headers.get("x-api-key") security_rules = app.AGENT_FUNCTION.get_copilot_security_rules() - agent_result = await run_copilot_agent( - stream=stream, - organization_id=organization.organization_id, - chat_request=chat_request, - chat_history=convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]), - global_llm_context=global_llm_context, - debug_run_info_text=debug_run_info_text, - llm_api_handler=llm_api_handler, - api_key=api_key, - security_rules=security_rules, - ) + with bind_copilot_session_id(chat.workflow_copilot_chat_id): + agent_result = await run_copilot_agent( + stream=stream, + organization_id=organization.organization_id, + chat_request=chat_request, + chat_history=convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]), + global_llm_context=global_llm_context, + debug_run_info_text=debug_run_info_text, + llm_api_handler=llm_api_handler, + api_key=api_key, + security_rules=security_rules, + ) user_response = agent_result.user_response updated_workflow = agent_result.updated_workflow @@ -1094,14 +1124,15 @@ async def workflow_copilot_chat_post( # SKY-8986: do not short-circuit on client disconnect. The LLM # call and the DB persistence below must complete so the reply # is in the chat history when the user reconnects. - user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm( - stream, - organization.organization_id, - chat_request, - convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]), - global_llm_context, - debug_run_info_text, - ) + with bind_copilot_session_id(chat.workflow_copilot_chat_id): + user_response, updated_workflow, updated_global_llm_context = await copilot_call_llm( + stream, + organization.organization_id, + chat_request, + convert_to_history_messages(chat_messages[-CHAT_HISTORY_CONTEXT_MESSAGES:]), + global_llm_context, + debug_run_info_text, + ) if updated_workflow and chat.auto_accept is not True: await app.DATABASE.workflow_params.update_workflow_copilot_chat( @@ -1217,6 +1248,80 @@ async def workflow_copilot_clear_proposed_workflow( raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found") +@base_router.post("/workflow/copilot/apply-proposed-workflow", include_in_schema=False) +async def workflow_copilot_apply_proposed_workflow( + apply_request: WorkflowCopilotApplyProposedWorkflowRequest, + organization: Organization = Depends(org_auth_service.get_current_org), +) -> Workflow: + """Accept a copilot proposal: stamp v1, write a new copilot-attributed version, clear the proposal.""" + chat = await app.DATABASE.workflow_params.get_workflow_copilot_chat_by_id( + organization_id=organization.organization_id, + workflow_copilot_chat_id=apply_request.workflow_copilot_chat_id, + ) + if chat is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Chat not found") + + proposal = chat.proposed_workflow + if not proposal: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No proposed workflow to apply") + + copilot_yaml = proposal.get("_copilot_yaml") if isinstance(proposal, dict) else None + if not copilot_yaml: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Proposed workflow has no copilot YAML to apply", + ) + + try: + yaml_request = _normalize_copilot_yaml(copilot_yaml) + except (yaml.YAMLError, ValidationError) as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Proposed copilot YAML is invalid: {e}", + ) + + current_workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( + workflow_permanent_id=chat.workflow_permanent_id, + organization_id=organization.organization_id, + ) + created_by_stamp = "copilot" if is_copilot_born_initial_write(current_workflow) else None + + if created_by_stamp == "copilot" and current_workflow is not None: + # Stamp v1 too so MIN(created_at)-per-WPID queries see copilot-born. + await app.WORKFLOW_SERVICE.update_workflow_definition( + workflow_id=current_workflow.workflow_id, + organization_id=organization.organization_id, + created_by="copilot", + edited_by="copilot", + ) + + new_workflow = await app.WORKFLOW_SERVICE.create_workflow_from_request( + organization=organization, + request=yaml_request, + workflow_permanent_id=chat.workflow_permanent_id, + created_by=created_by_stamp, + edited_by="copilot", + ) + + try: + # Best-effort: a 500 here would invite a retry that creates a duplicate version. + await app.DATABASE.workflow_params.update_workflow_copilot_chat( + organization_id=organization.organization_id, + workflow_copilot_chat_id=chat.workflow_copilot_chat_id, + proposed_workflow=None, + auto_accept=apply_request.auto_accept, + ) + except Exception: + LOG.warning( + "Failed to clear copilot proposal after applying it; new workflow version was created", + workflow_copilot_chat_id=chat.workflow_copilot_chat_id, + new_workflow_id=new_workflow.workflow_id, + exc_info=True, + ) + + return new_workflow + + def convert_to_history_messages( messages: list[WorkflowCopilotChatMessage], ) -> list[WorkflowCopilotChatHistoryMessage]: diff --git a/skyvern/forge/sdk/schemas/workflow_copilot.py b/skyvern/forge/sdk/schemas/workflow_copilot.py index 89d5a05fb..cc48593eb 100644 --- a/skyvern/forge/sdk/schemas/workflow_copilot.py +++ b/skyvern/forge/sdk/schemas/workflow_copilot.py @@ -49,6 +49,14 @@ class WorkflowCopilotClearProposedWorkflowRequest(BaseModel): auto_accept: bool = Field(..., description="Whether to auto-accept future workflow updates") +class WorkflowCopilotApplyProposedWorkflowRequest(BaseModel): + workflow_copilot_chat_id: str = Field(..., description="The chat whose proposed workflow should be applied") + auto_accept: bool = Field( + False, + description="If true, flip the chat to auto-accept mode so future turns persist directly without review", + ) + + class WorkflowCopilotChatHistoryMessage(BaseModel): sender: WorkflowCopilotChatSender = Field(..., description="Message sender") content: str = Field(..., description="Message content") diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 9687c9e90..383ba98ab 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -7153,6 +7153,7 @@ class WorkflowTriggerBlock(Block): skyvern_context.SkyvernContext( run_id=parent_context.run_id if parent_context else None, root_workflow_run_id=parent_context.root_workflow_run_id if parent_context else None, + copilot_session_id=parent_context.copilot_session_id if parent_context else None, ) ): try: diff --git a/skyvern/forge/sdk/workflow/models/workflow.py b/skyvern/forge/sdk/workflow/models/workflow.py index 361e59a0b..3f0701929 100644 --- a/skyvern/forge/sdk/workflow/models/workflow.py +++ b/skyvern/forge/sdk/workflow/models/workflow.py @@ -115,6 +115,8 @@ class Workflow(BaseModel): sequential_key: str | None = None folder_id: str | None = None import_error: str | None = None + created_by: str | None = None + edited_by: str | None = None @field_validator("run_with", mode="before") @classmethod @@ -201,6 +203,7 @@ class WorkflowRun(BaseModel): trigger_type: WorkflowRunTriggerType | None = None workflow_schedule_id: str | None = None ignore_inherited_workflow_system_prompt: bool = False + copilot_session_id: str | None = None @field_validator("run_with", mode="before") @classmethod diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index 439563597..c3b0828ab 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -651,6 +651,7 @@ class WorkflowService: trigger_type: WorkflowRunTriggerType | None = None, workflow_schedule_id: str | None = None, ignore_inherited_workflow_system_prompt: bool = False, + copilot_session_id: str | None = None, ) -> WorkflowRun: """ Create a workflow run and its parameters. Validate the workflow and the organization. If there are missing @@ -697,6 +698,16 @@ class WorkflowService: ) workflow_request.ai_fallback = True + # Inherit from ambient context so descendant runs (TriggerWorkflowBlock children) + # carry the parent's chat id forward without per-call plumbing. Resolved here so + # the same value reaches both the DB row and the new SkyvernContext below. + ambient_context: skyvern_context.SkyvernContext | None = skyvern_context.current() + resolved_copilot_session_id = ( + copilot_session_id + if copilot_session_id is not None + else (ambient_context.copilot_session_id if ambient_context else None) + ) + # Create the workflow run and set skyvern context workflow_run = await self.create_workflow_run( workflow_request=workflow_request, @@ -711,6 +722,7 @@ class WorkflowService: trigger_type=trigger_type, workflow_schedule_id=workflow_schedule_id, ignore_inherited_workflow_system_prompt=ignore_inherited_workflow_system_prompt, + copilot_session_id=resolved_copilot_session_id, ) LOG.info( f"Created workflow run {workflow_run.workflow_run_id} for workflow {workflow.workflow_id}", @@ -743,6 +755,7 @@ class WorkflowService: max_steps_override=max_steps_override, max_screenshot_scrolls=workflow_request.max_screenshot_scrolls, loop_internal_state=copy.deepcopy(context.loop_internal_state) if context else None, + copilot_session_id=resolved_copilot_session_id, ) ) @@ -3180,6 +3193,8 @@ class WorkflowService: adaptive_caching: bool = False, code_version: int | None = None, generate_script_on_terminal: bool = False, + created_by: str | None = None, + edited_by: str | None = None, ) -> Workflow: try: return await app.DATABASE.workflows.create_workflow( @@ -3208,6 +3223,8 @@ class WorkflowService: adaptive_caching=adaptive_caching, code_version=code_version, generate_script_on_terminal=generate_script_on_terminal, + created_by=created_by, + edited_by=edited_by, ) except IntegrityError as e: if "uc_org_permanent_id_version" in str(e) and workflow_permanent_id: @@ -3580,6 +3597,8 @@ class WorkflowService: cache_key: str | None = None, run_sequentially: bool | None = None, sequential_key: str | None | object = _UNSET, + created_by: str | None | object = _UNSET, + edited_by: str | None | object = _UNSET, ) -> Workflow: if workflow_definition is not None: updated_workflow = await app.DATABASE.workflows.update_workflow_and_reconcile_definition_params( @@ -3599,6 +3618,8 @@ class WorkflowService: cache_key=cache_key, run_sequentially=run_sequentially, sequential_key=sequential_key, + created_by=created_by, + edited_by=edited_by, ) return updated_workflow @@ -3619,6 +3640,8 @@ class WorkflowService: cache_key=cache_key, run_sequentially=run_sequentially, sequential_key=sequential_key, + created_by=created_by, + edited_by=edited_by, ) return updated_workflow @@ -3865,6 +3888,7 @@ class WorkflowService: trigger_type: WorkflowRunTriggerType | None = None, workflow_schedule_id: str | None = None, ignore_inherited_workflow_system_prompt: bool = False, + copilot_session_id: str | None = None, ) -> WorkflowRun: # validate the browser session or profile id browser_profile_id = workflow_request.browser_profile_id @@ -3956,6 +3980,7 @@ class WorkflowService: trigger_type=trigger_type, workflow_schedule_id=workflow_schedule_id, ignore_inherited_workflow_system_prompt=ignore_inherited_workflow_system_prompt, + copilot_session_id=copilot_session_id, ) async def _update_workflow_run_status( @@ -5220,6 +5245,8 @@ class WorkflowService: request: WorkflowCreateYAMLRequest, workflow_permanent_id: str | None = None, delete_script: bool = True, + created_by: str | None = None, + edited_by: str | None = None, ) -> Workflow: organization_id = organization.organization_id @@ -5288,6 +5315,8 @@ class WorkflowService: if request.code_version is not None else existing_latest_workflow.code_version, generate_script_on_terminal=request.generate_script_on_terminal, + created_by=created_by, + edited_by=edited_by, ) else: # NOTE: it's only potential, as it may be immediately deleted! @@ -5315,6 +5344,8 @@ class WorkflowService: adaptive_caching=request.adaptive_caching, code_version=request.code_version, generate_script_on_terminal=request.generate_script_on_terminal, + created_by=created_by, + edited_by=edited_by, ) # Keeping track of the new workflow id to delete it if an error occurs during the creation process new_workflow_id = potential_workflow.workflow_id diff --git a/skyvern/services/workflow_service.py b/skyvern/services/workflow_service.py index 6c1e793df..c6a1b46c0 100644 --- a/skyvern/services/workflow_service.py +++ b/skyvern/services/workflow_service.py @@ -28,6 +28,7 @@ async def prepare_workflow( parent_workflow_run_id: str | None = None, trigger_type: WorkflowRunTriggerType | None = None, ignore_inherited_workflow_system_prompt: bool = False, + copilot_session_id: str | None = None, ) -> WorkflowRun: """ Prepare a workflow to be run. @@ -49,6 +50,7 @@ async def prepare_workflow( parent_workflow_run_id=parent_workflow_run_id, trigger_type=trigger_type, ignore_inherited_workflow_system_prompt=ignore_inherited_workflow_system_prompt, + copilot_session_id=copilot_session_id, ) workflow = await app.WORKFLOW_SERVICE.get_workflow_by_permanent_id( diff --git a/tests/unit/forge/sdk/db/test_workflow_attribution_columns.py b/tests/unit/forge/sdk/db/test_workflow_attribution_columns.py new file mode 100644 index 000000000..192a5aa13 --- /dev/null +++ b/tests/unit/forge/sdk/db/test_workflow_attribution_columns.py @@ -0,0 +1,237 @@ +"""Regression tests for copilot attribution columns.""" + +from __future__ import annotations + +from typing import Any, AsyncGenerator + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import create_async_engine + +from skyvern.forge.sdk.core import skyvern_context +from skyvern.forge.sdk.db.agent_db import AgentDB +from skyvern.forge.sdk.db.models import Base + + +@pytest_asyncio.fixture +async def db_engine() -> AsyncGenerator[Any]: + engine = create_async_engine("sqlite+aiosqlite:///:memory:") + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) + yield engine + await engine.dispose() + + +@pytest_asyncio.fixture +async def agent_db(db_engine: Any) -> AsyncGenerator[AgentDB]: + yield AgentDB(database_string="sqlite+aiosqlite:///:memory:", debug_enabled=True, db_engine=db_engine) + + +@pytest_asyncio.fixture +async def org_id(agent_db: AgentDB) -> str: + org = await agent_db.organizations.create_organization( + organization_name="Attribution Org", + domain="attribution.test", + ) + return org.organization_id + + +@pytest.mark.asyncio +async def test_create_workflow_without_attribution_defaults_to_none(agent_db: AgentDB, org_id: str) -> None: + workflow = await agent_db.workflows.create_workflow( + title="plain-create", + workflow_definition={"parameters": [], "blocks": []}, + organization_id=org_id, + ) + assert workflow.created_by is None + assert workflow.edited_by is None + + +@pytest.mark.asyncio +async def test_create_workflow_stamps_attribution_when_passed(agent_db: AgentDB, org_id: str) -> None: + workflow = await agent_db.workflows.create_workflow( + title="copilot-create", + workflow_definition={"parameters": [], "blocks": []}, + organization_id=org_id, + created_by="copilot", + edited_by="copilot", + ) + assert workflow.created_by == "copilot" + assert workflow.edited_by == "copilot" + + +@pytest.mark.asyncio +async def test_update_workflow_omit_attribution_preserves_stamps(agent_db: AgentDB, org_id: str) -> None: + workflow = await agent_db.workflows.create_workflow( + title="seed", + workflow_definition={"parameters": [], "blocks": []}, + organization_id=org_id, + created_by="copilot", + edited_by="copilot", + ) + # Omit created_by / edited_by — the repo must NOT touch either column. + await agent_db.workflows.update_workflow( + workflow_id=workflow.workflow_id, + organization_id=org_id, + title="renamed", + ) + reread = await agent_db.workflows.get_workflow( + workflow_id=workflow.workflow_id, + organization_id=org_id, + ) + assert reread is not None + assert reread.created_by == "copilot" + assert reread.edited_by == "copilot" + + +@pytest.mark.asyncio +async def test_update_workflow_explicit_none_clears_attribution(agent_db: AgentDB, org_id: str) -> None: + # _UNSET sentinel distinguishes omit (preserve) from None (clear); rollback relies on this. + workflow = await agent_db.workflows.create_workflow( + title="seed", + workflow_definition={"parameters": [], "blocks": []}, + organization_id=org_id, + created_by="copilot", + edited_by="copilot", + ) + await agent_db.workflows.update_workflow( + workflow_id=workflow.workflow_id, + organization_id=org_id, + created_by=None, + edited_by=None, + ) + reread = await agent_db.workflows.get_workflow( + workflow_id=workflow.workflow_id, + organization_id=org_id, + ) + assert reread is not None + assert reread.created_by is None + assert reread.edited_by is None + + +@pytest.mark.asyncio +async def test_update_workflow_and_reconcile_explicit_none_clears_attribution(agent_db: AgentDB, org_id: str) -> None: + # Reconcile path must honor the same omit/None semantics as update_workflow. + from skyvern.forge.sdk.workflow.models.workflow import WorkflowDefinition + + workflow = await agent_db.workflows.create_workflow( + title="seed", + workflow_definition={"parameters": [], "blocks": []}, + organization_id=org_id, + created_by="copilot", + edited_by="copilot", + ) + await agent_db.workflows.update_workflow_and_reconcile_definition_params( + workflow_id=workflow.workflow_id, + organization_id=org_id, + workflow_definition=WorkflowDefinition(parameters=[], blocks=[]), + created_by=None, + edited_by=None, + ) + reread = await agent_db.workflows.get_workflow( + workflow_id=workflow.workflow_id, + organization_id=org_id, + ) + assert reread is not None + assert reread.created_by is None + assert reread.edited_by is None + + +@pytest.mark.asyncio +async def test_create_workflow_run_without_session_id_defaults_to_none(agent_db: AgentDB, org_id: str) -> None: + # No ambient skyvern_context; no explicit param — copilot_session_id stays NULL. + workflow = await agent_db.workflows.create_workflow( + title="wf", + workflow_definition={"parameters": [], "blocks": []}, + organization_id=org_id, + ) + run = await agent_db.workflow_runs.create_workflow_run( + workflow_permanent_id=workflow.workflow_permanent_id, + workflow_id=workflow.workflow_id, + organization_id=org_id, + ) + assert run.copilot_session_id is None + + +@pytest.mark.asyncio +async def test_create_workflow_run_explicit_session_id_persists(agent_db: AgentDB, org_id: str) -> None: + workflow = await agent_db.workflows.create_workflow( + title="wf", + workflow_definition={"parameters": [], "blocks": []}, + organization_id=org_id, + ) + run = await agent_db.workflow_runs.create_workflow_run( + workflow_permanent_id=workflow.workflow_permanent_id, + workflow_id=workflow.workflow_id, + organization_id=org_id, + copilot_session_id="chat_abc123", + ) + assert run.copilot_session_id == "chat_abc123" + + +@pytest.mark.asyncio +async def test_create_workflow_run_ignores_ambient_context(agent_db: AgentDB, org_id: str) -> None: + # Ambient-context resolution lives in the service layer, not the repo. Repo trusts the param. + workflow = await agent_db.workflows.create_workflow( + title="wf", + workflow_definition={"parameters": [], "blocks": []}, + organization_id=org_id, + ) + ambient = skyvern_context.SkyvernContext(copilot_session_id="chat_from_ctx") + with skyvern_context.scoped(ambient): + run = await agent_db.workflow_runs.create_workflow_run( + workflow_permanent_id=workflow.workflow_permanent_id, + workflow_id=workflow.workflow_id, + organization_id=org_id, + ) + assert run.copilot_session_id is None + + +# --------------------------------------------------------------------------- +# Stub-heuristic regression coverage +# --------------------------------------------------------------------------- + + +def _make_workflow_stub(*, version: int, created_by: str | None, block_count: int) -> Any: + blocks = [object()] * block_count + definition = type("D", (), {"blocks": blocks})() + return type( + "W", + (), + {"version": version, "created_by": created_by, "workflow_definition": definition}, + )() + + +def test_is_copilot_born_stub_true_on_version_one_empty_unstamped() -> None: + from skyvern.forge.sdk.copilot.attribution import is_copilot_born_initial_write + + wf = _make_workflow_stub(version=1, created_by=None, block_count=0) + assert is_copilot_born_initial_write(wf) is True + + +def test_is_copilot_born_stub_false_on_later_version() -> None: + # v1 is the only version that can be copilot-born; cleared v2+ would otherwise false-positive. + from skyvern.forge.sdk.copilot.attribution import is_copilot_born_initial_write + + wf = _make_workflow_stub(version=2, created_by=None, block_count=0) + assert is_copilot_born_initial_write(wf) is False + + +def test_is_copilot_born_stub_false_on_already_stamped() -> None: + from skyvern.forge.sdk.copilot.attribution import is_copilot_born_initial_write + + wf = _make_workflow_stub(version=1, created_by="copilot", block_count=0) + assert is_copilot_born_initial_write(wf) is False + + +def test_is_copilot_born_stub_false_on_non_empty_definition() -> None: + from skyvern.forge.sdk.copilot.attribution import is_copilot_born_initial_write + + wf = _make_workflow_stub(version=1, created_by=None, block_count=3) + assert is_copilot_born_initial_write(wf) is False + + +def test_is_copilot_born_stub_false_on_none() -> None: + from skyvern.forge.sdk.copilot.attribution import is_copilot_born_initial_write + + assert is_copilot_born_initial_write(None) is False diff --git a/tests/unit/test_copilot_session_injection.py b/tests/unit/test_copilot_session_injection.py index be0fdc20f..fe75c1daa 100644 --- a/tests/unit/test_copilot_session_injection.py +++ b/tests/unit/test_copilot_session_injection.py @@ -138,11 +138,7 @@ def test_mcp_to_copilot_error() -> None: class TestMcpBrowserContextBridge: - """Bridge-specific behavior of mcp_browser_context (not scoped_session). - - Covers: copilot session registry, API-key override install/reset, and the - teardown guarantees that must hold under every failure mode. - """ + """Bridge-specific behavior of mcp_browser_context.""" def _install_happy_path_mocks( self, monkeypatch: pytest.MonkeyPatch @@ -374,6 +370,7 @@ class TestUpdateWorkflowDirect: mock_wf_service = MagicMock() mock_wf_service.update_workflow_definition = AsyncMock() + mock_wf_service.get_workflow = AsyncMock(return_value=None) monkeypatch.setattr("skyvern.forge.sdk.copilot.tools.app.WORKFLOW_SERVICE", mock_wf_service) yaml_str = "title: Test\nworkflow_definition:\n blocks: []" @@ -405,6 +402,7 @@ class TestUpdateWorkflowDirect: mock_wf_service = MagicMock() mock_wf_service.update_workflow_definition = AsyncMock() + mock_wf_service.get_workflow = AsyncMock(return_value=None) monkeypatch.setattr("skyvern.forge.sdk.copilot.tools.app.WORKFLOW_SERVICE", mock_wf_service) result = await _update_workflow({"workflow_yaml": "title: Test"}, ctx) diff --git a/tests/unit/test_copilot_session_span_tag.py b/tests/unit/test_copilot_session_span_tag.py new file mode 100644 index 000000000..72358c28c --- /dev/null +++ b/tests/unit/test_copilot_session_span_tag.py @@ -0,0 +1,146 @@ +"""Tests for the copilot.session_id span attribute on LLM spans.""" + +from __future__ import annotations + +from types import ModuleType, SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +from skyvern.forge.sdk.api.llm.api_handler_factory import _enrich_llm_span +from skyvern.forge.sdk.core import skyvern_context +from skyvern.forge.sdk.core.skyvern_context import SkyvernContext + + +def _call_enrich(span: MagicMock) -> None: + _enrich_llm_span( + span, + model="gpt-5", + prompt_name="workflow-copilot", + prompt_tokens=10, + completion_tokens=20, + reasoning_tokens=0, + cached_tokens=0, + latency_ms=100, + llm_cost=0.001, + ) + + +def _set_attribute_keys(span: MagicMock) -> list[str]: + return [call.args[0] for call in span.set_attribute.call_args_list if call.args] + + +class TestEnrichLlmSpan: + def test_stamps_attribute_when_context_has_session_id(self) -> None: + span = MagicMock() + with skyvern_context.scoped(SkyvernContext(copilot_session_id="chat_xyz")): + _call_enrich(span) + span.set_attribute.assert_any_call("copilot.session_id", "chat_xyz") + + def test_no_attribute_when_context_has_no_session_id(self) -> None: + span = MagicMock() + with skyvern_context.scoped(SkyvernContext(copilot_session_id=None)): + _call_enrich(span) + assert "copilot.session_id" not in _set_attribute_keys(span) + + def test_no_attribute_when_no_ambient_context(self) -> None: + span = MagicMock() + skyvern_context.reset() + _call_enrich(span) + assert "copilot.session_id" not in _set_attribute_keys(span) + + +class _FakeAgentSpanData: + def __init__(self, name: str = "workflow-copilot") -> None: + self.name = name + + +class _FakeGenerationSpanData: + pass + + +class _FakeFunctionSpanData: + def __init__(self, name: str = "some_tool") -> None: + self.name = name + + +def _install_patch(monkeypatch: Any) -> Any: + # Wire ModuleType stubs for the full logfire chain — sys.modules entries alone aren't enough. + import sys + + from skyvern.forge.sdk.copilot import tracing_setup + + def _fake_original(span_data: Any, msg_template: str) -> dict[str, Any]: + attrs: dict[str, Any] = {} + if isinstance(span_data, _FakeAgentSpanData): + attrs["name"] = span_data.name + if isinstance(span_data, _FakeFunctionSpanData): + attrs["name"] = span_data.name + return attrs + + class _FakeWrapper: + @staticmethod + def create_span(*args: Any, **kwargs: Any) -> Any: + return None + + logfire_mod = ModuleType("logfire") + internal_mod = ModuleType("logfire._internal") + integrations_mod = ModuleType("logfire._internal.integrations") + oai_mod = ModuleType("logfire._internal.integrations.openai_agents") + oai_mod.attributes_from_span_data = _fake_original # type: ignore[attr-defined] + oai_mod.LogfireTraceProviderWrapper = _FakeWrapper # type: ignore[attr-defined] + logfire_mod._internal = internal_mod # type: ignore[attr-defined] + internal_mod.integrations = integrations_mod # type: ignore[attr-defined] + integrations_mod.openai_agents = oai_mod # type: ignore[attr-defined] + + monkeypatch.setitem(sys.modules, "logfire", logfire_mod) + monkeypatch.setitem(sys.modules, "logfire._internal", internal_mod) + monkeypatch.setitem(sys.modules, "logfire._internal.integrations", integrations_mod) + monkeypatch.setitem(sys.modules, "logfire._internal.integrations.openai_agents", oai_mod) + monkeypatch.setitem( + sys.modules, + "agents", + SimpleNamespace( + AgentSpanData=_FakeAgentSpanData, + GenerationSpanData=_FakeGenerationSpanData, + FunctionSpanData=_FakeFunctionSpanData, + ), + ) + + tracing_setup._patch_agent_span_attributes() + return oai_mod.attributes_from_span_data + + +class TestPatchedSpanAttributes: + def test_stamps_on_agent_span_when_context_has_session_id(self, monkeypatch: Any) -> None: + patched = _install_patch(monkeypatch) + with skyvern_context.scoped(SkyvernContext(copilot_session_id="chat_xyz")): + attrs = patched(_FakeAgentSpanData(), "Agent run: {name!r}") + assert attrs["copilot.session_id"] == "chat_xyz" + + def test_stamps_on_generation_span_when_context_has_session_id(self, monkeypatch: Any) -> None: + patched = _install_patch(monkeypatch) + with skyvern_context.scoped(SkyvernContext(copilot_session_id="chat_xyz")): + attrs = patched(_FakeGenerationSpanData(), "Generation") + assert attrs["copilot.session_id"] == "chat_xyz" + + def test_stamps_on_function_span_when_context_has_session_id(self, monkeypatch: Any) -> None: + patched = _install_patch(monkeypatch) + with skyvern_context.scoped(SkyvernContext(copilot_session_id="chat_xyz")): + attrs = patched(_FakeFunctionSpanData(), "Function call") + assert attrs["copilot.session_id"] == "chat_xyz" + + def test_no_attribute_when_context_has_no_session_id(self, monkeypatch: Any) -> None: + patched = _install_patch(monkeypatch) + with skyvern_context.scoped(SkyvernContext(copilot_session_id=None)): + attrs_agent = patched(_FakeAgentSpanData(), "Agent run: {name!r}") + attrs_gen = patched(_FakeGenerationSpanData(), "Generation") + attrs_fn = patched(_FakeFunctionSpanData(), "Function call") + assert "copilot.session_id" not in attrs_agent + assert "copilot.session_id" not in attrs_gen + assert "copilot.session_id" not in attrs_fn + + def test_no_attribute_when_no_ambient_context(self, monkeypatch: Any) -> None: + patched = _install_patch(monkeypatch) + skyvern_context.reset() + attrs = patched(_FakeAgentSpanData(), "Agent run: {name!r}") + assert "copilot.session_id" not in attrs diff --git a/tests/unit/test_copilot_task_block_rejection.py b/tests/unit/test_copilot_task_block_rejection.py index 701741067..40e10e09d 100644 --- a/tests/unit/test_copilot_task_block_rejection.py +++ b/tests/unit/test_copilot_task_block_rejection.py @@ -303,6 +303,7 @@ async def test_update_workflow_preserves_legacy_task_block_under_unchanged_label patch("skyvern.forge.sdk.copilot.tools._process_workflow_yaml", return_value=fake_workflow), patch("skyvern.forge.sdk.copilot.tools.app") as mock_app, ): + mock_app.WORKFLOW_SERVICE.get_workflow = AsyncMock(return_value=None) mock_app.WORKFLOW_SERVICE.update_workflow_definition = AsyncMock() result = await _update_workflow({"workflow_yaml": submitted}, ctx) @@ -343,6 +344,7 @@ async def test_update_workflow_allows_all_allowed_block_types() -> None: patch("skyvern.forge.sdk.copilot.tools._process_workflow_yaml", return_value=fake_workflow), patch("skyvern.forge.sdk.copilot.tools.app") as mock_app, ): + mock_app.WORKFLOW_SERVICE.get_workflow = AsyncMock(return_value=None) mock_app.WORKFLOW_SERVICE.update_workflow_definition = AsyncMock() result = await _update_workflow({"workflow_yaml": submitted}, ctx) diff --git a/tests/unit/test_workflow_copilot_session_context.py b/tests/unit/test_workflow_copilot_session_context.py new file mode 100644 index 000000000..a1b543a80 --- /dev/null +++ b/tests/unit/test_workflow_copilot_session_context.py @@ -0,0 +1,49 @@ +"""Tests for the bind_copilot_session_id context manager.""" + +from __future__ import annotations + +import pytest + +from skyvern.forge.sdk.core import skyvern_context +from skyvern.forge.sdk.core.skyvern_context import SkyvernContext +from skyvern.forge.sdk.routes.workflow_copilot import bind_copilot_session_id + + +class TestBindCopilotSessionId: + def test_sets_id_during_scope_when_ambient_context_present(self) -> None: + with skyvern_context.scoped(SkyvernContext(copilot_session_id=None)): + with bind_copilot_session_id("chat_xyz"): + ctx = skyvern_context.current() + assert ctx is not None + assert ctx.copilot_session_id == "chat_xyz" + + def test_restores_prior_value_on_normal_exit(self) -> None: + with skyvern_context.scoped(SkyvernContext(copilot_session_id="outer")): + with bind_copilot_session_id("inner"): + assert skyvern_context.current().copilot_session_id == "inner" # type: ignore[union-attr] + assert skyvern_context.current().copilot_session_id == "outer" # type: ignore[union-attr] + + def test_restores_prior_value_when_body_raises(self) -> None: + class _Boom(RuntimeError): + pass + + with skyvern_context.scoped(SkyvernContext(copilot_session_id="outer")): + with pytest.raises(_Boom): + with bind_copilot_session_id("inner"): + raise _Boom("body raised") + assert skyvern_context.current().copilot_session_id == "outer" # type: ignore[union-attr] + + def test_noop_when_chat_id_is_none(self) -> None: + with skyvern_context.scoped(SkyvernContext(copilot_session_id="outer")): + with bind_copilot_session_id(None): + # No overwrite — the outer value must stick. + assert skyvern_context.current().copilot_session_id == "outer" # type: ignore[union-attr] + assert skyvern_context.current().copilot_session_id == "outer" # type: ignore[union-attr] + + def test_noop_when_no_ambient_context(self) -> None: + skyvern_context.reset() + # Helper must not raise when there is no context to mutate — the + # copilot route should still function, just without the tag. + with bind_copilot_session_id("chat_xyz"): + assert skyvern_context.current() is None + assert skyvern_context.current() is None