TOTP code db + agent support for fetching totp_code from db (#784)

Co-authored-by: Shuchang Zheng <wintonzheng0325@gmail.com>
This commit is contained in:
Kerem Yilmaz 2024-09-08 15:07:03 -07:00 committed by GitHub
parent d878ee5a0d
commit b9f5e33876
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 243 additions and 26 deletions

View file

@ -0,0 +1,69 @@
"""create totp_codes table and add task.totp_identifier
Revision ID: c5848cc524b1
Revises: c50f0aa0ef24
Create Date: 2024-09-08 21:59:56.666276+00:00
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c5848cc524b1"
down_revision: Union[str, None] = "c50f0aa0ef24"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"totp_codes",
sa.Column("totp_code_id", sa.String(), nullable=False),
sa.Column("totp_identifier", sa.String(), nullable=False),
sa.Column("organization_id", sa.String(), nullable=True),
sa.Column("task_id", sa.String(), nullable=True),
sa.Column("workflow_id", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=False),
sa.Column("code", sa.String(), nullable=False),
sa.Column("source", sa.String(), nullable=True),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column("modified_at", sa.DateTime(), nullable=False),
sa.Column("expired_at", sa.DateTime(), nullable=True),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.organization_id"],
),
sa.ForeignKeyConstraint(
["task_id"],
["tasks.task_id"],
),
sa.ForeignKeyConstraint(
["workflow_id"],
["workflows.workflow_id"],
),
sa.PrimaryKeyConstraint("totp_code_id"),
)
op.create_index(op.f("ix_totp_codes_created_at"), "totp_codes", ["created_at"], unique=False)
op.create_index(op.f("ix_totp_codes_expired_at"), "totp_codes", ["expired_at"], unique=False)
op.create_index(op.f("ix_totp_codes_totp_identifier"), "totp_codes", ["totp_identifier"], unique=False)
op.add_column("tasks", sa.Column("totp_identifier", sa.String(), nullable=True))
op.add_column("workflow_runs", sa.Column("totp_identifier", sa.String(), nullable=True))
op.add_column("workflows", sa.Column("totp_identifier", sa.String(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("workflows", "totp_identifier")
op.drop_column("workflow_runs", "totp_identifier")
op.drop_column("tasks", "totp_identifier")
op.drop_index(op.f("ix_totp_codes_totp_identifier"), table_name="totp_codes")
op.drop_index(op.f("ix_totp_codes_expired_at"), table_name="totp_codes")
op.drop_index(op.f("ix_totp_codes_created_at"), table_name="totp_codes")
op.drop_table("totp_codes")
# ### end Alembic commands ###

View file

@ -117,6 +117,9 @@ class Settings(BaseSettings):
AZURE_GPT4O_MINI_API_BASE: str | None = None
AZURE_GPT4O_MINI_API_VERSION: str | None = None
# TOTP Settings
TOTP_LIFESPAN_MINUTES: int = 10
def is_cloud_environment(self) -> bool:
"""
:return: True if env is not local, else False

View file

@ -122,7 +122,8 @@ class ForgeAgent:
url=task_url,
title=task_block.title,
webhook_callback_url=None,
totp_verification_url=None,
totp_verification_url=task_block.totp_verification_url,
totp_identifier=task_block.totp_identifier,
navigation_goal=task_block.navigation_goal,
data_extraction_goal=task_block.data_extraction_goal,
navigation_payload=navigation_payload,
@ -178,6 +179,7 @@ class ForgeAgent:
title=task_request.title,
webhook_callback_url=task_request.webhook_callback_url,
totp_verification_url=task_request.totp_verification_url,
totp_identifier=task_request.totp_identifier,
navigation_goal=task_request.navigation_goal,
data_extraction_goal=task_request.data_extraction_goal,
navigation_payload=task_request.navigation_payload,
@ -983,7 +985,7 @@ class ForgeAgent:
task,
browser_state,
element_tree_in_prompt,
verification_code_check=bool(task.totp_verification_url),
verification_code_check=bool(task.totp_verification_url or task.totp_identifier),
expire_verification_code=True,
)
@ -1055,7 +1057,7 @@ class ForgeAgent:
final_navigation_payload = task.navigation_payload
current_context = skyvern_context.ensure_context()
verification_code = current_context.totp_codes.get(task.task_id)
if task.totp_verification_url and verification_code:
if (task.totp_verification_url or task.totp_identifier) and verification_code:
if (
isinstance(final_navigation_payload, dict)
and SPECIAL_FIELD_VERIFICATION_CODE not in final_navigation_payload
@ -1598,10 +1600,13 @@ class ForgeAgent:
json_response: dict[str, Any],
) -> dict[str, Any]:
need_verification_code = json_response.get("need_verification_code")
if need_verification_code and task.totp_verification_url and task.organization_id:
if need_verification_code and (task.totp_verification_url or task.totp_identifier) and task.organization_id:
LOG.info("Need verification code", step_id=step.step_id)
verification_code = await poll_verification_code(
task.task_id, task.organization_id, url=task.totp_verification_url
task.task_id,
task.organization_id,
totp_verification_url=task.totp_verification_url,
totp_identifier=task.totp_identifier,
)
current_context = skyvern_context.ensure_context()
current_context.totp_codes[task.task_id] = verification_code

View file

@ -22,6 +22,7 @@ from skyvern.forge.sdk.db.models import (
StepModel,
TaskGenerationModel,
TaskModel,
TOTPCodeModel,
WorkflowModel,
WorkflowParameterModel,
WorkflowRunModel,
@ -48,6 +49,7 @@ from skyvern.forge.sdk.db.utils import (
from skyvern.forge.sdk.models import Organization, OrganizationAuthToken, Step, StepStatus
from skyvern.forge.sdk.schemas.task_generations import TaskGeneration
from skyvern.forge.sdk.schemas.tasks import ProxyLocation, Task, TaskStatus
from skyvern.forge.sdk.schemas.totp_codes import TOTPCode
from skyvern.forge.sdk.workflow.models.parameter import (
AWSSecretParameter,
BitwardenLoginCredentialParameter,
@ -84,6 +86,7 @@ class AgentDB:
navigation_payload: dict[str, Any] | list | str | None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
organization_id: str | None = None,
proxy_location: ProxyLocation | None = None,
extracted_information_schema: dict[str, Any] | list | str | None = None,
@ -101,6 +104,7 @@ class AgentDB:
title=title,
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
navigation_goal=navigation_goal,
data_extraction_goal=data_extraction_goal,
navigation_payload=navigation_payload,
@ -819,6 +823,7 @@ class AgentDB:
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
persist_browser_session: bool = False,
workflow_permanent_id: str | None = None,
version: int | None = None,
@ -833,6 +838,7 @@ class AgentDB:
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
persist_browser_session=persist_browser_session,
is_saved_task=is_saved_task,
)
@ -1001,6 +1007,7 @@ class AgentDB:
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
) -> WorkflowRun:
try:
async with self.Session() as session:
@ -1012,6 +1019,7 @@ class AgentDB:
status="created",
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
)
session.add(workflow_run)
await session.commit()
@ -1439,3 +1447,27 @@ class AgentDB:
if not task_generation:
return None
return TaskGeneration.model_validate(task_generation)
async def get_totp_codes(
self,
organization_id: str,
totp_identifier: str,
valid_lifespan_minutes: int = settings.TOTP_LIFESPAN_MINUTES,
) -> list[TOTPCode]:
"""
1. filter by:
- organization_id
- totp_identifier
2. make sure created_at is within the valid lifespan
3. sort by created_at desc
"""
async with self.Session() as session:
query = (
select(TOTPCodeModel)
.filter_by(organization_id=organization_id)
.filter_by(totp_identifier=totp_identifier)
.filter(TOTPCodeModel.created_at > datetime.utcnow() - timedelta(minutes=valid_lifespan_minutes))
.order_by(TOTPCodeModel.created_at.desc())
)
totp_code = (await session.scalars(query)).all()
return [TOTPCode.model_validate(totp_code) for totp_code in totp_code]

View file

@ -119,6 +119,11 @@ def generate_task_generation_id() -> str:
return f"{TASK_GENERATION_PREFIX}_{int_id}"
def generate_totp_code_id() -> str:
int_id = generate_id()
return f"totp_{int_id}"
def generate_id() -> int:
"""
generate a 64-bit int ID

View file

@ -29,6 +29,7 @@ from skyvern.forge.sdk.db.id import (
generate_step_id,
generate_task_generation_id,
generate_task_id,
generate_totp_code_id,
generate_workflow_id,
generate_workflow_parameter_id,
generate_workflow_permanent_id,
@ -49,6 +50,7 @@ class TaskModel(Base):
status = Column(String, index=True)
webhook_callback_url = Column(String)
totp_verification_url = Column(String)
totp_identifier = Column(String)
title = Column(String)
url = Column(String)
navigation_goal = Column(String)
@ -180,6 +182,7 @@ class WorkflowModel(Base):
proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String)
totp_verification_url = Column(String)
totp_identifier = Column(String)
persist_browser_session = Column(Boolean, default=False, nullable=False)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
@ -207,6 +210,7 @@ class WorkflowRunModel(Base):
proxy_location = Column(Enum(ProxyLocation))
webhook_callback_url = Column(String)
totp_verification_url = Column(String)
totp_identifier = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(
@ -392,3 +396,19 @@ class TaskGenerationModel(Base):
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
class TOTPCodeModel(Base):
__tablename__ = "totp_codes"
totp_code_id = Column(String, primary_key=True, default=generate_totp_code_id)
totp_identifier = Column(String, nullable=False, index=True)
organization_id = Column(String, ForeignKey("organizations.organization_id"))
task_id = Column(String, ForeignKey("tasks.task_id"))
workflow_id = Column(String, ForeignKey("workflows.workflow_id"))
content = Column(String, nullable=False)
code = Column(String, nullable=False)
source = Column(String)
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False, index=True)
modified_at = Column(DateTime, default=datetime.datetime.utcnow, onupdate=datetime.datetime.utcnow, nullable=False)
expired_at = Column(DateTime, index=True)

View file

@ -64,6 +64,7 @@ def convert_to_task(task_obj: TaskModel, debug_enabled: bool = False) -> Task:
url=task_obj.url,
webhook_callback_url=task_obj.webhook_callback_url,
totp_verification_url=task_obj.totp_verification_url,
totp_identifier=task_obj.totp_identifier,
navigation_goal=task_obj.navigation_goal,
data_extraction_goal=task_obj.data_extraction_goal,
navigation_payload=task_obj.navigation_payload,
@ -162,6 +163,7 @@ def convert_to_workflow(workflow_model: WorkflowModel, debug_enabled: bool = Fal
workflow_permanent_id=workflow_model.workflow_permanent_id,
webhook_callback_url=workflow_model.webhook_callback_url,
totp_verification_url=workflow_model.totp_verification_url,
totp_identifier=workflow_model.totp_identifier,
persist_browser_session=workflow_model.persist_browser_session,
proxy_location=(ProxyLocation(workflow_model.proxy_location) if workflow_model.proxy_location else None),
version=workflow_model.version,
@ -192,6 +194,7 @@ def convert_to_workflow_run(workflow_run_model: WorkflowRunModel, debug_enabled:
),
webhook_callback_url=workflow_run_model.webhook_callback_url,
totp_verification_url=workflow_run_model.totp_verification_url,
totp_identifier=workflow_run_model.totp_identifier,
created_at=workflow_run_model.created_at,
modified_at=workflow_run_model.modified_at,
)

View file

@ -39,6 +39,7 @@ class TaskRequest(BaseModel):
examples=["https://my-webhook.com"],
)
totp_verification_url: str | None = None
totp_identifier: str | None = None
navigation_goal: str | None = Field(
default=None,
description="The user's goal for the task.",

View file

@ -0,0 +1,29 @@
from datetime import datetime
from pydantic import BaseModel, ConfigDict
class TOTPCodeBase(BaseModel):
model_config = ConfigDict(from_attributes=True)
totp_identifier: str | None = None
organization_id: str | None = None
task_id: str | None = None
workflow_id: str | None = None
source: str | None = None
content: str | None = None
expired_at: datetime | None = None
class TOTPCodeCreate(TOTPCodeBase):
totp_identifier: str
organization_id: str
content: str
class TOTPCode(TOTPCodeCreate):
totp_code_id: str
code: str
created_at: datetime
modified_at: datetime

View file

@ -176,6 +176,8 @@ class TaskBlock(Block):
max_steps_per_run: int | None = None
parameters: list[PARAMETER_TYPE] = []
complete_on_download: bool = False
totp_verification_url: str | None = None
totp_identifier: str | None = None
def get_all_parameters(
self,

View file

@ -15,6 +15,7 @@ class WorkflowRequestBody(BaseModel):
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
class RunWorkflowResponse(BaseModel):
@ -51,6 +52,7 @@ class Workflow(BaseModel):
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
persist_browser_session: bool = False
created_at: datetime
@ -75,6 +77,7 @@ class WorkflowRun(BaseModel):
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
created_at: datetime
modified_at: datetime
@ -101,6 +104,7 @@ class WorkflowRunStatusResponse(BaseModel):
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
created_at: datetime
modified_at: datetime
parameters: dict[str, Any]

View file

@ -46,9 +46,7 @@ class BitwardenSensitiveInformationParameterYAML(ParameterYAML):
# Parameter 1 of Literal[...] cannot be of type "Any"
# This pattern already works in block.py but since the ParameterType is not defined in this file, mypy is not able
# to infer the type of the parameter_type attribute.
parameter_type: Literal[ParameterType.BITWARDEN_SENSITIVE_INFORMATION] = (
ParameterType.BITWARDEN_SENSITIVE_INFORMATION
) # type: ignore
parameter_type: Literal["bitwarden_sensitive_information"] = ParameterType.BITWARDEN_SENSITIVE_INFORMATION # type: ignore
# bitwarden cli required fields
bitwarden_client_id_aws_secret_key: str
@ -113,6 +111,8 @@ class TaskBlockYAML(BlockYAML):
max_steps_per_run: int | None = None
parameter_keys: list[str] | None = None
complete_on_download: bool = False
totp_verification_url: str | None = None
totp_identifier: str | None = None
class ForLoopBlockYAML(BlockYAML):
@ -225,6 +225,7 @@ class WorkflowCreateYAMLRequest(BaseModel):
proxy_location: ProxyLocation | None = None
webhook_callback_url: str | None = None
totp_verification_url: str | None = None
totp_identifier: str | None = None
persist_browser_session: bool = False
workflow_definition: WorkflowDefinitionYAML
is_saved_task: bool = False

View file

@ -286,6 +286,7 @@ class WorkflowService:
proxy_location: ProxyLocation | None = None,
webhook_callback_url: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
persist_browser_session: bool = False,
workflow_permanent_id: str | None = None,
version: int | None = None,
@ -299,6 +300,7 @@ class WorkflowService:
proxy_location=proxy_location,
webhook_callback_url=webhook_callback_url,
totp_verification_url=totp_verification_url,
totp_identifier=totp_identifier,
persist_browser_session=persist_browser_session,
workflow_permanent_id=workflow_permanent_id,
version=version,
@ -397,6 +399,7 @@ class WorkflowService:
proxy_location=workflow_request.proxy_location,
webhook_callback_url=workflow_request.webhook_callback_url,
totp_verification_url=workflow_request.totp_verification_url,
totp_identifier=workflow_request.totp_identifier,
)
async def mark_workflow_run_as_completed(self, workflow_run_id: str) -> None:
@ -640,6 +643,7 @@ class WorkflowService:
proxy_location=workflow_run.proxy_location,
webhook_callback_url=workflow_run.webhook_callback_url,
totp_verification_url=workflow_run.totp_verification_url,
totp_identifier=workflow_run.totp_identifier,
created_at=workflow_run.created_at,
modified_at=workflow_run.modified_at,
parameters=parameters_with_value,
@ -835,6 +839,7 @@ class WorkflowService:
proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url,
totp_verification_url=request.totp_verification_url,
totp_identifier=request.totp_identifier,
persist_browser_session=request.persist_browser_session,
workflow_permanent_id=workflow_permanent_id,
version=existing_version + 1,
@ -849,6 +854,7 @@ class WorkflowService:
proxy_location=request.proxy_location,
webhook_callback_url=request.webhook_callback_url,
totp_verification_url=request.totp_verification_url,
totp_identifier=request.totp_identifier,
persist_browser_session=request.persist_browser_session,
is_saved_task=request.is_saved_task,
)
@ -912,7 +918,8 @@ class WorkflowService:
bitwarden_client_id_aws_secret_key=parameter.bitwarden_client_id_aws_secret_key,
bitwarden_client_secret_aws_secret_key=parameter.bitwarden_client_secret_aws_secret_key,
bitwarden_master_password_aws_secret_key=parameter.bitwarden_master_password_aws_secret_key,
bitwarden_collection_id=parameter.bitwarden_collection_id,
# TODO: remove "# type: ignore" after ensuring bitwarden_collection_id is always set
bitwarden_collection_id=parameter.bitwarden_collection_id, # type: ignore
bitwarden_identity_key=parameter.bitwarden_identity_key,
bitwarden_identity_fields=parameter.bitwarden_identity_fields,
key=parameter.key,
@ -1046,6 +1053,8 @@ class WorkflowService:
max_retries=block_yaml.max_retries,
complete_on_download=block_yaml.complete_on_download,
continue_on_failure=block_yaml.continue_on_failure,
totp_verification_url=block_yaml.totp_verification_url,
totp_identifier=block_yaml.totp_identifier,
)
elif block_yaml.block_type == BlockType.FOR_LOOP:
loop_blocks = [

View file

@ -1931,7 +1931,13 @@ async def get_input_value(tag_name: str, locator: Locator) -> str | None:
return await locator.inner_text()
async def poll_verification_code(task_id: str, organization_id: str, url: str) -> str | None:
async def poll_verification_code(
task_id: str,
organization_id: str,
workflow_id: str | None = None,
totp_verification_url: str | None = None,
totp_identifier: str | None = None,
) -> str | None:
timeout = timedelta(minutes=VERIFICATION_CODE_POLLING_TIMEOUT_MINS)
start_datetime = datetime.utcnow()
timeout_datetime = start_datetime + timeout
@ -1943,13 +1949,28 @@ async def poll_verification_code(task_id: str, organization_id: str, url: str) -
# check timeout
if datetime.utcnow() > timeout_datetime:
return None
verification_code = None
if totp_verification_url:
verification_code = await _get_verification_code_from_url(task_id, totp_verification_url, org_token.token)
elif totp_identifier:
verification_code = await _get_verification_code_from_db(
task_id, organization_id, totp_identifier, workflow_id=workflow_id
)
if verification_code:
LOG.info("Got verification code", verification_code=verification_code)
return verification_code
await asyncio.sleep(10)
async def _get_verification_code_from_url(task_id: str, url: str, api_key: str) -> str | None:
request_data = {
"task_id": task_id,
}
payload = json.dumps(request_data)
signature = generate_skyvern_signature(
payload=payload,
api_key=org_token.token,
api_key=api_key,
)
timestamp = str(int(datetime.utcnow().timestamp()))
headers = {
@ -1958,9 +1979,22 @@ async def poll_verification_code(task_id: str, organization_id: str, url: str) -
"Content-Type": "application/json",
}
json_resp = await aiohttp_post(url=url, data=request_data, headers=headers, raise_exception=False)
verification_code = json_resp.get("verification_code", None)
if verification_code:
LOG.info("Got verification code", verification_code=verification_code)
return verification_code
return json_resp.get("verification_code", None)
await asyncio.sleep(10)
async def _get_verification_code_from_db(
task_id: str,
organization_id: str,
totp_identifier: str,
workflow_id: str | None = None,
) -> str | None:
totp_codes = await app.DATABASE.get_totp_codes(organization_id=organization_id, totp_identifier=totp_identifier)
for totp_code in totp_codes:
if totp_code.workflow_id and workflow_id and totp_code.workflow_id != workflow_id:
continue
if totp_code.task_id and totp_code.task_id != task_id:
continue
if totp_code.expired_at and totp_code.expired_at < datetime.utcnow():
continue
return totp_code.code
return None