From 3fce82760fd1f71600131f7ecdfcff774c722707 Mon Sep 17 00:00:00 2001 From: pedrohsdb Date: Fri, 24 Apr 2026 17:02:12 -0700 Subject: [PATCH] =?UTF-8?q?feat:=20persist=20real=20LLM=20cost=20to=20DB?= =?UTF-8?q?=20+=20block-level=20attribution=20(public=20API=20unchanged=20?= =?UTF-8?q?=E2=80=94=20placeholder=20preserved)=20(#5656)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ...560_add_llm_cost_to_workflow_run_blocks.py | 30 ++++++ skyvern/forge/sdk/api/llm/api_handler.py | 2 + .../forge/sdk/api/llm/api_handler_factory.py | 71 ++++++++++-- skyvern/forge/sdk/db/models.py | 3 + skyvern/forge/sdk/db/repositories/observer.py | 63 ++++++++++- skyvern/forge/sdk/db/repositories/tasks.py | 17 +++ skyvern/forge/sdk/workflow/models/block.py | 67 ++++++++++-- skyvern/forge/sdk/workflow/service.py | 39 ++++++- tests/unit/test_api_handler_factory.py | 18 ++-- .../unit/test_block_llm_cost_kwargs_wiring.py | 101 ++++++++++++++++++ 10 files changed, 383 insertions(+), 28 deletions(-) create mode 100644 alembic/versions/2026_04_24_2353-c19d7d385560_add_llm_cost_to_workflow_run_blocks.py create mode 100644 tests/unit/test_block_llm_cost_kwargs_wiring.py diff --git a/alembic/versions/2026_04_24_2353-c19d7d385560_add_llm_cost_to_workflow_run_blocks.py b/alembic/versions/2026_04_24_2353-c19d7d385560_add_llm_cost_to_workflow_run_blocks.py new file mode 100644 index 000000000..11e48eebf --- /dev/null +++ b/alembic/versions/2026_04_24_2353-c19d7d385560_add_llm_cost_to_workflow_run_blocks.py @@ -0,0 +1,30 @@ +"""add llm_cost to workflow_run_blocks + +Revision ID: c19d7d385560 +Revises: c9005bafa5ec +Create Date: 2026-04-24T23:53:46.912017+00:00 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c19d7d385560" +down_revision: Union[str, None] = "c9005bafa5ec" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "workflow_run_blocks", + sa.Column("llm_cost", sa.Numeric(), nullable=False, server_default="0"), + ) + + +def downgrade() -> None: + op.drop_column("workflow_run_blocks", "llm_cost") diff --git a/skyvern/forge/sdk/api/llm/api_handler.py b/skyvern/forge/sdk/api/llm/api_handler.py index 1d0305389..48fa152be 100644 --- a/skyvern/forge/sdk/api/llm/api_handler.py +++ b/skyvern/forge/sdk/api/llm/api_handler.py @@ -15,6 +15,7 @@ class LLMAPIHandler(Protocol): task_v2: TaskV2 | None = None, thought: Thought | None = None, ai_suggestion: AISuggestion | None = None, + workflow_run_block_id: str | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, organization_id: str | None = None, @@ -34,6 +35,7 @@ async def dummy_llm_api_handler( task_v2: TaskV2 | None = None, thought: Thought | None = None, ai_suggestion: AISuggestion | None = None, + workflow_run_block_id: str | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, organization_id: str | None = None, diff --git a/skyvern/forge/sdk/api/llm/api_handler_factory.py b/skyvern/forge/sdk/api/llm/api_handler_factory.py index 004a6c1b8..e8f3445ff 100644 --- a/skyvern/forge/sdk/api/llm/api_handler_factory.py +++ b/skyvern/forge/sdk/api/llm/api_handler_factory.py @@ -222,13 +222,45 @@ def _normalize_llm_model(model: str | None) -> str | None: return model.split("/")[-1] -def _assert_step_thought_exclusive(step: Step | None, thought: Thought | None) -> None: - # step and thought write the same llm_cost to different tables - # (steps.step_cost vs observer_thoughts.thought_cost). int_org_llm_costs - # UNION ALLs them, so setting both would double-count cost in - # fct_org_margin.llm_cost. - if step is not None and thought is not None: - raise ValueError("LLM API handler invoked with both step and thought set — these are mutually exclusive") +def _assert_step_thought_block_exclusive( + step: Step | None, + thought: Thought | None, + workflow_run_block_id: str | None, +) -> None: + # Each LLM call writes cost to exactly one of: steps.step_cost, + # observer_thoughts.thought_cost, workflow_run_blocks.llm_cost. + # Both the run-level SUM and the int_org_llm_costs dbt model rely on + # this exclusivity to avoid double-counting. + set_count = sum(1 for x in (step, thought, workflow_run_block_id) if x is not None) + if set_count > 1: + raise ValueError( + "LLM API handler invoked with more than one of step / thought / workflow_run_block_id set — " + "these are mutually exclusive" + ) + + +async def _persist_block_llm_cost( + workflow_run_block_id: str, + organization_id: str | None, + context: skyvern_context.SkyvernContext | None, + llm_cost: float, + prompt_name: str | None, +) -> None: + """Increment workflow_run_blocks.llm_cost or warn if no org_id resolves.""" + block_org_id = organization_id or (context.organization_id if context else None) + if block_org_id: + await app.DATABASE.observer.increment_workflow_run_block_llm_cost( + workflow_run_block_id=workflow_run_block_id, + organization_id=block_org_id, + amount=llm_cost, + ) + else: + LOG.warning( + "Block LLM cost dropped: workflow_run_block_id set but no organization_id resolved", + workflow_run_block_id=workflow_run_block_id, + llm_cost=llm_cost, + prompt_name=prompt_name, + ) def _convert_allowed_fails_policy(policy: LLMAllowedFailsPolicy | None) -> AllowedFailsPolicy | None: @@ -562,6 +594,7 @@ class LLMAPIHandlerFactory: task_v2: TaskV2 | None = None, thought: Thought | None = None, ai_suggestion: AISuggestion | None = None, + workflow_run_block_id: str | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, organization_id: str | None = None, @@ -584,7 +617,7 @@ class LLMAPIHandlerFactory: Returns: The response from the LLM router. """ - _assert_step_thought_exclusive(step, thought) + _assert_step_thought_block_exclusive(step, thought, workflow_run_block_id) start_time = time.perf_counter() _llm_span = otel_trace.get_current_span() _llm_span.set_attribute("llm_key", llm_key) @@ -1000,6 +1033,12 @@ class LLMAPIHandlerFactory: cached_token_count=cached_tokens if cached_tokens > 0 else None, last_llm_model=actual_model, ) + if workflow_run_block_id: + # Atomic UPDATE: description gen (asyncio.create_task in + # execute_safe) races with the block's own execute() calls. + await _persist_block_llm_cost( + workflow_run_block_id, organization_id, context, llm_cost, prompt_name + ) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict) parsed_response_json = json.dumps(parsed_response, indent=2) if should_persist_llm_artifacts: @@ -1154,6 +1193,7 @@ class LLMAPIHandlerFactory: task_v2: TaskV2 | None = None, thought: Thought | None = None, ai_suggestion: AISuggestion | None = None, + workflow_run_block_id: str | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, organization_id: str | None = None, @@ -1164,7 +1204,7 @@ class LLMAPIHandlerFactory: force_dict: bool = True, system_prompt: str | None = None, ) -> dict[str, Any] | Any: - _assert_step_thought_exclusive(step, thought) + _assert_step_thought_block_exclusive(step, thought, workflow_run_block_id) start_time = time.perf_counter() _llm_span = otel_trace.get_current_span() # handler_type distinguishes the three LLM entry points that share @@ -1527,6 +1567,10 @@ class LLMAPIHandlerFactory: thought_cost=llm_cost, last_llm_model=actual_model, ) + if workflow_run_block_id: + await _persist_block_llm_cost( + workflow_run_block_id, organization_id, context, llm_cost, prompt_name + ) parsed_response = parse_api_response(response, llm_config.add_assistant_prefix, force_dict) parsed_response_json = json.dumps(parsed_response, indent=2) if should_persist_llm_artifacts: @@ -1749,6 +1793,7 @@ class LLMCaller: task_v2: TaskV2 | None = None, thought: Thought | None = None, ai_suggestion: AISuggestion | None = None, + workflow_run_block_id: str | None = None, screenshots: list[bytes] | None = None, parameters: dict[str, Any] | None = None, organization_id: str | None = None, @@ -1760,7 +1805,7 @@ class LLMCaller: system_prompt: str | None = None, **extra_parameters: Any, ) -> dict[str, Any] | Any: - _assert_step_thought_exclusive(step, thought) + _assert_step_thought_block_exclusive(step, thought, workflow_run_block_id) start_time = time.perf_counter() _llm_span = otel_trace.get_current_span() _llm_span.set_attribute("llm_key", self.llm_key) @@ -1996,6 +2041,12 @@ class LLMCaller: thought_cost=call_stats.llm_cost, last_llm_model=actual_model, ) + if workflow_run_block_id and call_stats and call_stats.llm_cost is not None: + # call_stats.llm_cost is None when litellm can't compute cost + # (volcengine, some OPENAI_COMPATIBLE targets). + await _persist_block_llm_cost( + workflow_run_block_id, organization_id, context, call_stats.llm_cost, prompt_name + ) organization_id = organization_id or ( step.organization_id if step else (thought.organization_id if thought else None) diff --git a/skyvern/forge/sdk/db/models.py b/skyvern/forge/sdk/db/models.py index de2ea1091..707f5a8e4 100644 --- a/skyvern/forge/sdk/db/models.py +++ b/skyvern/forge/sdk/db/models.py @@ -824,6 +824,9 @@ class WorkflowRunBlockModel(Base): executed_branch_result = Column(Boolean, nullable=True) executed_branch_next_block = Column(String, nullable=True) + # Accumulates LLM cost for block-scoped calls (no step/thought attribution). + llm_cost = Column(Numeric, default=0, nullable=False) + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False) diff --git a/skyvern/forge/sdk/db/repositories/observer.py b/skyvern/forge/sdk/db/repositories/observer.py index da44a54c1..fdb1a3ca9 100644 --- a/skyvern/forge/sdk/db/repositories/observer.py +++ b/skyvern/forge/sdk/db/repositories/observer.py @@ -3,7 +3,8 @@ from __future__ import annotations from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Callable -from sqlalchemy import and_, delete, select +import structlog +from sqlalchemy import and_, delete, func, select, update from sqlalchemy.exc import SQLAlchemyError from skyvern.forge.sdk.db._error_handling import db_operation @@ -30,6 +31,8 @@ from skyvern.forge.sdk.schemas.workflow_runs import WorkflowRunBlock from skyvern.schemas.runs import ProxyLocationInput, RunEngine from skyvern.schemas.workflows import BlockStatus, BlockType +LOG = structlog.get_logger() + class ObserverRepository(BaseRepository): """Database operations for observer tasks (TaskV2), thoughts, and workflow run blocks.""" @@ -120,6 +123,64 @@ class ObserverRepository(BaseRepository): thoughts = (await session.scalars(query)).all() return [Thought.model_validate(thought) for thought in thoughts] + @db_operation("get_thought_cost_sum_by_workflow_run_id") + async def get_thought_cost_sum_by_workflow_run_id(self, workflow_run_id: str, organization_id: str) -> float: + """Sum `thought_cost` across all thoughts for the given workflow_run_id. + + Returns 0.0 for runs without task_v2 planning. + """ + async with self.Session() as session: + query = ( + select(func.coalesce(func.sum(ThoughtModel.thought_cost), 0)) + .where(ThoughtModel.workflow_run_id == workflow_run_id) + .where(ThoughtModel.organization_id == organization_id) + ) + total = (await session.execute(query)).scalar_one() + return float(total) + + @db_operation("get_block_llm_cost_sum_by_workflow_run_id") + async def get_block_llm_cost_sum_by_workflow_run_id(self, workflow_run_id: str, organization_id: str) -> float: + """Sum `llm_cost` across all workflow_run_blocks for this workflow_run_id.""" + async with self.Session() as session: + query = ( + select(func.coalesce(func.sum(WorkflowRunBlockModel.llm_cost), 0)) + .where(WorkflowRunBlockModel.workflow_run_id == workflow_run_id) + .where(WorkflowRunBlockModel.organization_id == organization_id) + ) + total = (await session.execute(query)).scalar_one() + return float(total) + + @db_operation("increment_workflow_run_block_llm_cost") + async def increment_workflow_run_block_llm_cost( + self, + workflow_run_block_id: str, + organization_id: str, + amount: float, + ) -> None: + """Atomically add `amount` to `workflow_run_blocks.llm_cost`. + + Single SQL UPDATE so concurrent writers don't lose increments. + No-op for non-positive `amount`. + """ + if amount <= 0: + return + async with self.Session() as session: + stmt = ( + update(WorkflowRunBlockModel) + .where(WorkflowRunBlockModel.workflow_run_block_id == workflow_run_block_id) + .where(WorkflowRunBlockModel.organization_id == organization_id) + .values(llm_cost=WorkflowRunBlockModel.llm_cost + amount) + ) + result = await session.execute(stmt) + await session.commit() + if result.rowcount == 0: + LOG.warning( + "Block LLM cost increment matched zero rows — cost dropped", + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, + amount=amount, + ) + @db_operation("create_task_v2") async def create_task_v2( self, diff --git a/skyvern/forge/sdk/db/repositories/tasks.py b/skyvern/forge/sdk/db/repositories/tasks.py index 4504f872d..eac12d56e 100644 --- a/skyvern/forge/sdk/db/repositories/tasks.py +++ b/skyvern/forge/sdk/db/repositories/tasks.py @@ -229,6 +229,23 @@ class TasksRepository(BaseRepository): row = (await session.execute(query)).one() return row.total, row.completed + @db_operation("get_step_cost_sum_by_task_ids") + async def get_step_cost_sum_by_task_ids(self, task_ids: list[str], organization_id: str) -> float: + """Sum `step_cost` across all steps belonging to the given task_ids. + + Returns 0.0 for empty task_ids. Includes failed steps. + """ + if not task_ids: + return 0.0 + async with self.Session() as session: + query = ( + select(func.coalesce(func.sum(StepModel.step_cost), 0)) + .where(StepModel.task_id.in_(task_ids)) + .where(StepModel.organization_id == organization_id) + ) + total = (await session.execute(query)).scalar_one() + return float(total) + @db_operation("get_workflow_run_progress_timestamps") async def get_workflow_run_progress_timestamps( self, diff --git a/skyvern/forge/sdk/workflow/models/block.py b/skyvern/forge/sdk/workflow/models/block.py index 6419304dc..9a1eb9282 100644 --- a/skyvern/forge/sdk/workflow/models/block.py +++ b/skyvern/forge/sdk/workflow/models/block.py @@ -568,7 +568,10 @@ class Block(BaseModel, abc.ABC): block=block_data, ) json_response = await app.SECONDARY_LLM_API_HANDLER( - prompt=description_generation_prompt, prompt_name="generate-workflow-run-block-description" + prompt=description_generation_prompt, + prompt_name="generate-workflow-run-block-description", + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, ) description = json_response.get("summary") LOG.info( @@ -2833,7 +2836,12 @@ class TextPromptBlock(Block): prompt=prompt, llm_key=self.llm_key, ) - response = await llm_api_handler(prompt=prompt, prompt_name="text-prompt") + response = await llm_api_handler( + prompt=prompt, + prompt_name="text-prompt", + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, + ) if workflow_run_block: artifacts_to_persist.append((ArtifactType.LLM_RESPONSE, json.dumps(response).encode("utf-8"))) @@ -3980,7 +3988,12 @@ class FileParserBlock(Block): file_url=self.file_url, file_type=self.file_type, error=f"Failed to parse Excel file: {str(e)}" ) - async def _parse_pdf_file(self, file_path: str) -> str: + async def _parse_pdf_file( + self, + file_path: str, + workflow_run_block_id: str | None = None, + organization_id: str | None = None, + ) -> str: """Parse PDF file and return extracted text. Uses the shared PDF parsing utility that tries pypdf first, @@ -4016,6 +4029,8 @@ class FileParserBlock(Block): prompt_name="extract-text-from-image", screenshots=page_images, force_dict=True, + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, ) return llm_response.get("extracted_text", "") except Exception: @@ -4025,7 +4040,12 @@ class FileParserBlock(Block): ) raise - async def _parse_image_file(self, file_path: str) -> str: + async def _parse_image_file( + self, + file_path: str, + workflow_run_block_id: str | None = None, + organization_id: str | None = None, + ) -> str: """Parse image file using vision LLM for OCR.""" try: with open(file_path, "rb") as f: @@ -4040,6 +4060,8 @@ class FileParserBlock(Block): prompt_name="extract-text-from-image", screenshots=[image_bytes], force_dict=True, + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, ) return llm_response.get("extracted_text", "") except Exception: @@ -4117,7 +4139,11 @@ class FileParserBlock(Block): ) async def _extract_with_ai( - self, content: str | list[dict[str, Any]], workflow_run_context: WorkflowRunContext + self, + content: str | list[dict[str, Any]], + workflow_run_context: WorkflowRunContext, + workflow_run_block_id: str | None = None, + organization_id: str | None = None, ) -> dict[str, Any]: """Extract structured data using AI based on json_schema.""" # Use local variable to avoid mutating the instance @@ -4146,7 +4172,11 @@ class FileParserBlock(Block): llm_api_handler = LLMAPIHandlerFactory.get_override_llm_api_handler(llm_key, default=app.LLM_API_HANDLER) llm_response = await llm_api_handler( - prompt=llm_prompt, prompt_name="extract-information-from-file-text", force_dict=False + prompt=llm_prompt, + prompt_name="extract-information-from-file-text", + force_dict=False, + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, ) return llm_response @@ -4223,9 +4253,17 @@ class FileParserBlock(Block): elif self.file_type == FileType.EXCEL: parsed_data = await self._parse_excel_file(file_path) elif self.file_type == FileType.PDF: - parsed_data = await self._parse_pdf_file(file_path) + parsed_data = await self._parse_pdf_file( + file_path, + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, + ) elif self.file_type == FileType.IMAGE: - parsed_data = await self._parse_image_file(file_path) + parsed_data = await self._parse_image_file( + file_path, + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, + ) elif self.file_type == FileType.DOCX: parsed_data = await self._parse_docx_file(file_path) else: @@ -4250,7 +4288,12 @@ class FileParserBlock(Block): if self.json_schema: try: - ai_extracted_data = await self._extract_with_ai(parsed_data, workflow_run_context) + ai_extracted_data = await self._extract_with_ai( + parsed_data, + workflow_run_context, + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, + ) final_data = ai_extracted_data except Exception as e: return await self.build_block_result( @@ -4370,7 +4413,11 @@ class PDFParserBlock(Block): "extract-information-from-file-text", extracted_text_content=extracted_text, json_schema=self.json_schema ) llm_response = await app.LLM_API_HANDLER( - prompt=llm_prompt, prompt_name="extract-information-from-file-text", force_dict=False + prompt=llm_prompt, + prompt_name="extract-information-from-file-text", + force_dict=False, + workflow_run_block_id=workflow_run_block_id, + organization_id=organization_id, ) # Record the parsed data await self.record_output_parameter_value(workflow_run_context, workflow_run_id, llm_response) diff --git a/skyvern/forge/sdk/workflow/service.py b/skyvern/forge/sdk/workflow/service.py index ab65c744a..8ef33ad71 100644 --- a/skyvern/forge/sdk/workflow/service.py +++ b/skyvern/forge/sdk/workflow/service.py @@ -4428,6 +4428,43 @@ class WorkflowService: urls = await app.ARTIFACT_MANAGER.get_share_links_with_bundle_support(artifacts) return [u for u in urls if u is not None] + async def get_workflow_run_llm_cost_sum( + self, + workflow_run_id: str, + organization_id: str, + ) -> float: + """Sum per-LLM-call cost across step_cost, thought_cost, and + workflow_run_blocks.llm_cost for this workflow_run. + + `organization_id` is required: passing None makes repo filters + evaluate as `IS NULL` and silently return 0.0. + """ + if not organization_id: + raise ValueError( + "get_workflow_run_llm_cost_sum requires organization_id; " + "passing None would compile to IS NULL and silently return 0.0" + ) + # thought + block sums are independent of the task list; run them in + # parallel with the task fetch + step sum (which depends on task ids). + thought_task = app.DATABASE.observer.get_thought_cost_sum_by_workflow_run_id( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + block_task = app.DATABASE.observer.get_block_llm_cost_sum_by_workflow_run_id( + workflow_run_id=workflow_run_id, + organization_id=organization_id, + ) + workflow_run_tasks = await app.DATABASE.tasks.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id) + step_cost_sum, thought_cost_sum, block_llm_cost_sum = await asyncio.gather( + app.DATABASE.tasks.get_step_cost_sum_by_task_ids( + task_ids=[task.task_id for task in workflow_run_tasks], + organization_id=organization_id, + ), + thought_task, + block_task, + ) + return step_cost_sum + thought_cost_sum + block_llm_cost_sum + async def build_workflow_run_status_response_by_workflow_id( self, workflow_run_id: str, @@ -4600,7 +4637,7 @@ class WorkflowService: text_prompt_blocks = [ block for block in workflow_run_blocks if block.block_type == BlockType.TEXT_PROMPT ] - # TODO: This is a temporary cost calculation. We need to implement a more accurate cost calculation. + # This is a temporary cost calculation. total_cost = 0.05 * (completed_step_count + len(text_prompt_blocks)) return WorkflowRunResponseBase( workflow_id=workflow.workflow_permanent_id, diff --git a/tests/unit/test_api_handler_factory.py b/tests/unit/test_api_handler_factory.py index 8c61ed58b..abe633403 100644 --- a/tests/unit/test_api_handler_factory.py +++ b/tests/unit/test_api_handler_factory.py @@ -236,15 +236,21 @@ def test_normalize_llm_model_strips_provider_prefix() -> None: assert api_handler_factory._normalize_llm_model(None) is None -def test_assert_step_thought_exclusive_rejects_both_set() -> None: +def test_assert_step_thought_block_exclusive_rejects_both_set() -> None: with pytest.raises(ValueError, match="mutually exclusive"): - api_handler_factory._assert_step_thought_exclusive(MagicMock(), MagicMock()) + api_handler_factory._assert_step_thought_block_exclusive(MagicMock(), MagicMock(), None) -def test_assert_step_thought_exclusive_allows_single_or_neither() -> None: - api_handler_factory._assert_step_thought_exclusive(None, None) - api_handler_factory._assert_step_thought_exclusive(MagicMock(), None) - api_handler_factory._assert_step_thought_exclusive(None, MagicMock()) +def test_assert_step_thought_block_exclusive_rejects_step_and_block() -> None: + with pytest.raises(ValueError, match="mutually exclusive"): + api_handler_factory._assert_step_thought_block_exclusive(MagicMock(), None, "wfb_123") + + +def test_assert_step_thought_block_exclusive_allows_single_or_neither() -> None: + api_handler_factory._assert_step_thought_block_exclusive(None, None, None) + api_handler_factory._assert_step_thought_block_exclusive(MagicMock(), None, None) + api_handler_factory._assert_step_thought_block_exclusive(None, MagicMock(), None) + api_handler_factory._assert_step_thought_block_exclusive(None, None, "wfb_123") @pytest.mark.asyncio diff --git a/tests/unit/test_block_llm_cost_kwargs_wiring.py b/tests/unit/test_block_llm_cost_kwargs_wiring.py new file mode 100644 index 000000000..e406edefe --- /dev/null +++ b/tests/unit/test_block_llm_cost_kwargs_wiring.py @@ -0,0 +1,101 @@ +"""AST guard: every awaited LLM handler call inside the six +block-scoped methods below must pass both `workflow_run_block_id=` +and `organization_id=`. Catches call-site wiring regressions. +""" + +from __future__ import annotations + +import ast +import pathlib + +BLOCK_PY = ( + pathlib.Path(__file__).resolve().parents[2] / "skyvern" / "forge" / "sdk" / "workflow" / "models" / "block.py" +) + +# Methods that make block-scoped LLM calls. Each must pass +# workflow_run_block_id= and organization_id= to any LLM handler call +# in its body so the handler can associate the call with the block. +BLOCK_ATTRIBUTED_METHODS = { + "_generate_workflow_run_block_description", # BaseBlock description gen + "send_prompt", # TextPromptBlock + "_parse_pdf_file", # FileParserBlock PDF vision + "_parse_image_file", # FileParserBlock image OCR + "_extract_with_ai", # FileParserBlock schema extract + "execute", # PDFParserBlock (deprecated) execute +} + +# Names that indicate an LLM handler call (both module-level handlers and +# locally-bound `llm_api_handler` from LLMAPIHandlerFactory.get_override_llm_api_handler). +LLM_HANDLER_NAME_FRAGMENTS = ("LLM_API_HANDLER", "llm_api_handler") + + +def _is_llm_handler_call(call: ast.Call) -> bool: + """True if the call looks like `await (whatever).LLM_API_HANDLER(...)` or + `await llm_api_handler(...)`.""" + func = call.func + if isinstance(func, ast.Attribute): + return any(frag in func.attr for frag in LLM_HANDLER_NAME_FRAGMENTS) + if isinstance(func, ast.Name): + return any(frag in func.id for frag in LLM_HANDLER_NAME_FRAGMENTS) + return False + + +def _find_llm_calls_in_method(method: ast.AsyncFunctionDef) -> list[ast.Call]: + """Collect LLM handler INVOCATIONS (awaited calls) only — not factory + lookups like `LLMAPIHandlerFactory.get_override_llm_api_handler(...)` + which happen to contain `llm_api_handler` in the name but aren't the + thing that makes an LLM request.""" + calls: list[ast.Call] = [] + for node in ast.walk(method): + if not isinstance(node, ast.Await): + continue + inner = node.value + if isinstance(inner, ast.Call) and _is_llm_handler_call(inner): + calls.append(inner) + return calls + + +def _kwarg_names(call: ast.Call) -> set[str]: + return {kw.arg for kw in call.keywords if kw.arg is not None} + + +def test_every_block_scoped_llm_call_passes_both_cost_attribution_kwargs() -> None: + tree = ast.parse(BLOCK_PY.read_text()) + + offenders: list[str] = [] + methods_with_calls: set[str] = set() + + for node in ast.walk(tree): + if not isinstance(node, ast.AsyncFunctionDef): + continue + if node.name not in BLOCK_ATTRIBUTED_METHODS: + continue + + for call in _find_llm_calls_in_method(node): + methods_with_calls.add(node.name) + kwargs = _kwarg_names(call) + missing = {"workflow_run_block_id", "organization_id"} - kwargs + if missing: + offenders.append( + f"{node.name} @ line {call.lineno}: missing {sorted(missing)} " + f"on LLM handler call (kwargs present: {sorted(kwargs)})" + ) + + # Sanity: every method in BLOCK_ATTRIBUTED_METHODS must have at least one + # awaited LLM call. If our name-fragment matcher misses (e.g. a handler + # rename), the kwargs check would silently pass with zero calls inspected. + assert methods_with_calls == BLOCK_ATTRIBUTED_METHODS, ( + f"Expected every method in BLOCK_ATTRIBUTED_METHODS to contain at least one " + f"awaited LLM handler call. Methods with no calls found: " + f"{sorted(BLOCK_ATTRIBUTED_METHODS - methods_with_calls)}. " + f"Either the matcher needs a new name fragment, or the method no longer " + f"makes a block-scoped LLM call (remove from BLOCK_ATTRIBUTED_METHODS)." + ) + + assert not offenders, ( + "Block-scoped LLM calls are missing cost-attribution kwargs:\n " + + "\n ".join(offenders) + + "\n\nEvery LLM handler call inside these methods must pass both " + "`workflow_run_block_id=` and `organization_id=` to correctly attribute " + "cost to `workflow_run_blocks.llm_cost`." + )