Implement get_share_links (#302)

This commit is contained in:
Kerem Yilmaz 2024-05-13 00:03:31 -07:00 committed by GitHub
parent 6a83f367ba
commit 20a86590dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 41 additions and 38 deletions

View file

@ -888,16 +888,10 @@ class ForgeAgent:
artifact_types=[ArtifactType.SCREENSHOT_ACTION],
n=SettingsManager.get_settings().TASK_RESPONSE_ACTION_SCREENSHOT_COUNT,
)
latest_action_screenshot_urls = []
if latest_action_screenshot_artifacts:
for artifact in latest_action_screenshot_artifacts:
screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(artifact)
if screenshot_url:
latest_action_screenshot_urls.append(screenshot_url)
else:
LOG.error(
"Failed to get share link for action screenshot",
artifact_id=artifact.artifact_id,
latest_action_screenshot_urls = await app.ARTIFACT_MANAGER.get_share_links(
latest_action_screenshot_artifacts
)
else:
LOG.error("Failed to get latest action screenshots")

View file

@ -75,17 +75,21 @@ class AsyncAWSClient:
return None
@execute_with_async_client(client_type=AWSClientType.S3)
async def create_presigned_url(self, uri: str, client: AioBaseClient = None) -> str | None:
async def create_presigned_urls(self, uris: list[str], client: AioBaseClient = None) -> list[str] | None:
presigned_urls = []
try:
for uri in uris:
parsed_uri = S3Uri(uri)
url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": parsed_uri.bucket, "Key": parsed_uri.key},
ExpiresIn=SettingsManager.get_settings().PRESIGNED_URL_EXPIRATION,
)
return url
presigned_urls.append(url)
return presigned_urls
except Exception:
LOG.exception("Failed to create presigned url.", uri=uri)
LOG.exception("Failed to create presigned url for S3 objects.", uris=uris)
return None

View file

@ -71,6 +71,9 @@ class ArtifactManager:
async def get_share_link(self, artifact: Artifact) -> str | None:
return await app.STORAGE.get_share_link(artifact)
async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None:
return await app.STORAGE.get_share_links(artifacts)
async def wait_for_upload_aiotasks_for_task(self, task_id: str) -> None:
try:
st = time.time()

View file

@ -40,6 +40,10 @@ class BaseStorage(ABC):
async def get_share_link(self, artifact: Artifact) -> str | None:
pass
@abstractmethod
async def get_share_links(self, artifacts: list[Artifact]) -> list[str] | None:
pass
@abstractmethod
async def store_artifact_from_path(self, artifact: Artifact, path: str) -> None:
pass

View file

@ -52,6 +52,9 @@ class LocalStorage(BaseStorage):
async def get_share_link(self, artifact: Artifact) -> str:
return artifact.uri
async def get_share_links(self, artifacts: list[Artifact]) -> list[str]:
return [artifact.uri for artifact in artifacts]
@staticmethod
def _parse_uri_to_path(uri: str) -> str:
parsed_uri = urlparse(uri)

View file

@ -1,4 +1,3 @@
import asyncio
from typing import Annotated, Any
import structlog
@ -238,17 +237,9 @@ async def get_task(
artifact_types=[ArtifactType.SCREENSHOT_ACTION],
n=SettingsManager.get_settings().TASK_RESPONSE_ACTION_SCREENSHOT_COUNT,
)
latest_action_screenshot_urls = []
latest_action_screenshot_urls: list[str] | None = None
if latest_action_screenshot_artifacts:
for artifact in latest_action_screenshot_artifacts:
screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(artifact)
if screenshot_url:
latest_action_screenshot_urls.append(screenshot_url)
else:
LOG.error(
"Failed to get share link for action screenshot",
artifact_id=artifact.artifact_id,
)
latest_action_screenshot_urls = await app.ARTIFACT_MANAGER.get_share_links(latest_action_screenshot_artifacts)
elif task_obj.status in [TaskStatus.terminated, TaskStatus.completed]:
LOG.error(
"Failed to get latest action screenshots in task response",
@ -416,9 +407,12 @@ async def get_agent_task_step_artifacts(
organization_id=current_org.organization_id,
)
if SettingsManager.get_settings().ENV != "local":
signed_urls = await asyncio.gather(*[app.ARTIFACT_MANAGER.get_share_link(artifact) for artifact in artifacts])
signed_urls = await app.ARTIFACT_MANAGER.get_share_links(artifacts)
if signed_urls:
for i, artifact in enumerate(artifacts):
artifact.signed_url = signed_urls[i]
else:
LOG.error("Failed to get signed urls for artifacts", task_id=task_id, step_id=step_id)
return ORJSONResponse([artifact.model_dump() for artifact in artifacts])

View file

@ -451,7 +451,8 @@ class WorkflowService:
workflow_run = await self.get_workflow_run(workflow_run_id=workflow_run_id)
workflow_run_tasks = await app.DATABASE.get_tasks_by_workflow_run_id(workflow_run_id=workflow_run_id)
screenshot_urls = []
screenshot_artifacts = []
screenshot_urls: list[str] | None = None
# get the last screenshot for the last 3 tasks of the workflow run
for task in workflow_run_tasks[::-1]:
screenshot_artifact = await app.DATABASE.get_latest_artifact(
@ -460,11 +461,11 @@ class WorkflowService:
organization_id=organization_id,
)
if screenshot_artifact:
screenshot_url = await app.ARTIFACT_MANAGER.get_share_link(screenshot_artifact)
if screenshot_url:
screenshot_urls.append(screenshot_url)
if len(screenshot_urls) >= 3:
screenshot_artifacts.append(screenshot_artifact)
if len(screenshot_artifacts) >= 3:
break
if screenshot_artifacts:
screenshot_urls = await app.ARTIFACT_MANAGER.get_share_links(screenshot_artifacts)
recording_url = None
recording_artifact = await app.DATABASE.get_artifact_for_workflow_run(