task v2 refactor part 6 - observer_cruise_id -> task_v2_id (#1817)

This commit is contained in:
Shuchang Zheng 2025-02-23 16:03:49 -08:00 committed by GitHub
parent 2d24055c36
commit ffbc95e1b4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 238 additions and 250 deletions

View file

@ -210,7 +210,7 @@ class AgentDB:
task_id: str | None = None,
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_cruise_id: str | None = None,
task_v2_id: str | None = None,
observer_thought_id: str | None = None,
ai_suggestion_id: str | None = None,
organization_id: str | None = None,
@ -225,7 +225,7 @@ class AgentDB:
step_id=step_id,
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_cruise_id=observer_cruise_id,
observer_cruise_id=task_v2_id,
observer_thought_id=observer_thought_id,
ai_suggestion_id=ai_suggestion_id,
organization_id=organization_id,
@ -807,9 +807,9 @@ class AgentDB:
return convert_to_organization_auth_token(auth_token)
async def get_artifacts_for_observer_cruise(
async def get_artifacts_for_task_v2(
self,
observer_cruise_id: str,
task_v2_id: str,
organization_id: str | None = None,
artifact_types: list[ArtifactType] | None = None,
) -> list[Artifact]:
@ -817,7 +817,7 @@ class AgentDB:
async with self.Session() as session:
query = (
select(ArtifactModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
if artifact_types:
@ -894,7 +894,7 @@ class AgentDB:
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_thought_id: str | None = None,
observer_cruise_id: str | None = None,
task_v2_id: str | None = None,
organization_id: str | None = None,
) -> list[Artifact]:
try:
@ -913,8 +913,8 @@ class AgentDB:
query = query.filter_by(workflow_run_block_id=workflow_run_block_id)
if observer_thought_id is not None:
query = query.filter_by(observer_thought_id=observer_thought_id)
if observer_cruise_id is not None:
query = query.filter_by(observer_cruise_id=observer_cruise_id)
if task_v2_id is not None:
query = query.filter_by(observer_cruise_id=task_v2_id)
if organization_id is not None:
query = query.filter_by(organization_id=organization_id)
@ -938,7 +938,7 @@ class AgentDB:
workflow_run_id: str | None = None,
workflow_run_block_id: str | None = None,
observer_thought_id: str | None = None,
observer_cruise_id: str | None = None,
task_v2_id: str | None = None,
organization_id: str | None = None,
) -> Artifact | None:
artifacts = await self.get_artifacts_by_entity_id(
@ -948,7 +948,7 @@ class AgentDB:
workflow_run_id=workflow_run_id,
workflow_run_block_id=workflow_run_block_id,
observer_thought_id=observer_thought_id,
observer_cruise_id=observer_cruise_id,
task_v2_id=task_v2_id,
organization_id=organization_id,
)
return artifacts[0] if artifacts else None
@ -1915,13 +1915,11 @@ class AgentDB:
await session.execute(stmt)
await session.commit()
async def delete_observer_cruise_artifacts(
self, observer_cruise_id: str, organization_id: str | None = None
) -> None:
async def delete_task_v2_artifacts(self, task_v2_id: str, organization_id: str | None = None) -> None:
async with self.Session() as session:
stmt = delete(ArtifactModel).where(
and_(
ArtifactModel.observer_cruise_id == observer_cruise_id,
ArtifactModel.observer_cruise_id == task_v2_id,
ArtifactModel.organization_id == organization_id,
)
)
@ -2130,47 +2128,43 @@ class AgentDB:
await session.execute(stmt)
await session.commit()
async def get_observer_cruise(
self, observer_cruise_id: str, organization_id: str | None = None
) -> ObserverTask | None:
async def get_task_v2(self, task_v2_id: str, organization_id: str | None = None) -> ObserverTask | None:
async with self.Session() as session:
if observer_cruise := (
if task_v2 := (
await session.scalars(
select(ObserverCruiseModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
).first():
return ObserverTask.model_validate(observer_cruise)
return ObserverTask.model_validate(task_v2)
return None
async def delete_observer_thoughts_for_cruise(
self, observer_cruise_id: str, organization_id: str | None = None
) -> None:
async def delete_observer_thoughts_for_cruise(self, task_v2_id: str, organization_id: str | None = None) -> None:
async with self.Session() as session:
stmt = delete(ObserverThoughtModel).where(
and_(
ObserverThoughtModel.observer_cruise_id == observer_cruise_id,
ObserverThoughtModel.observer_cruise_id == task_v2_id,
ObserverThoughtModel.organization_id == organization_id,
)
)
await session.execute(stmt)
await session.commit()
async def get_observer_cruise_by_workflow_run_id(
async def get_task_v2_by_workflow_run_id(
self,
workflow_run_id: str,
organization_id: str | None = None,
) -> ObserverTask | None:
async with self.Session() as session:
if observer_cruise := (
if task_v2 := (
await session.scalars(
select(ObserverCruiseModel)
.filter_by(organization_id=organization_id)
.filter_by(workflow_run_id=workflow_run_id)
)
).first():
return ObserverTask.model_validate(observer_cruise)
return ObserverTask.model_validate(task_v2)
return None
async def get_observer_thought(
@ -2189,14 +2183,14 @@ class AgentDB:
async def get_observer_thoughts(
self,
observer_cruise_id: str,
task_v2_id: str,
observer_thought_types: list[ObserverThoughtType] | None = None,
organization_id: str | None = None,
) -> list[ObserverThought]:
async with self.Session() as session:
query = (
select(ObserverThoughtModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
.order_by(ObserverThoughtModel.created_at)
)
@ -2205,7 +2199,7 @@ class AgentDB:
observer_thoughts = (await session.scalars(query)).all()
return [ObserverThought.model_validate(thought) for thought in observer_thoughts]
async def create_observer_cruise(
async def create_task_v2(
self,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
@ -2219,7 +2213,7 @@ class AgentDB:
webhook_callback_url: str | None = None,
) -> ObserverTask:
async with self.Session() as session:
new_observer_cruise = ObserverCruiseModel(
new_task_v2 = ObserverCruiseModel(
workflow_run_id=workflow_run_id,
workflow_id=workflow_id,
workflow_permanent_id=workflow_permanent_id,
@ -2231,14 +2225,14 @@ class AgentDB:
webhook_callback_url=webhook_callback_url,
organization_id=organization_id,
)
session.add(new_observer_cruise)
session.add(new_task_v2)
await session.commit()
await session.refresh(new_observer_cruise)
return ObserverTask.model_validate(new_observer_cruise)
await session.refresh(new_task_v2)
return ObserverTask.model_validate(new_task_v2)
async def create_observer_thought(
self,
observer_cruise_id: str,
task_v2_id: str,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
workflow_permanent_id: str | None = None,
@ -2257,7 +2251,7 @@ class AgentDB:
) -> ObserverThought:
async with self.Session() as session:
new_observer_thought = ObserverThoughtModel(
observer_cruise_id=observer_cruise_id,
observer_cruise_id=task_v2_id,
workflow_run_id=workflow_run_id,
workflow_id=workflow_id,
workflow_permanent_id=workflow_permanent_id,
@ -2331,9 +2325,9 @@ class AgentDB:
return ObserverThought.model_validate(observer_thought)
raise NotFoundError(f"ObserverThought {observer_thought_id}")
async def update_observer_cruise(
async def update_task_v2(
self,
observer_cruise_id: str,
task_v2_id: str,
status: ObserverTaskStatus | None = None,
workflow_run_id: str | None = None,
workflow_id: str | None = None,
@ -2345,34 +2339,34 @@ class AgentDB:
organization_id: str | None = None,
) -> ObserverTask:
async with self.Session() as session:
observer_cruise = (
task_v2 = (
await session.scalars(
select(ObserverCruiseModel)
.filter_by(observer_cruise_id=observer_cruise_id)
.filter_by(observer_cruise_id=task_v2_id)
.filter_by(organization_id=organization_id)
)
).first()
if observer_cruise:
if task_v2:
if status:
observer_cruise.status = status
task_v2.status = status
if workflow_run_id:
observer_cruise.workflow_run_id = workflow_run_id
task_v2.workflow_run_id = workflow_run_id
if workflow_id:
observer_cruise.workflow_id = workflow_id
task_v2.workflow_id = workflow_id
if workflow_permanent_id:
observer_cruise.workflow_permanent_id = workflow_permanent_id
task_v2.workflow_permanent_id = workflow_permanent_id
if url:
observer_cruise.url = url
task_v2.url = url
if prompt:
observer_cruise.prompt = prompt
task_v2.prompt = prompt
if summary:
observer_cruise.summary = summary
task_v2.summary = summary
if output:
observer_cruise.output = output
task_v2.output = output
await session.commit()
await session.refresh(observer_cruise)
return ObserverTask.model_validate(observer_cruise)
raise NotFoundError(f"ObserverTask {observer_cruise_id} not found")
await session.refresh(task_v2)
return ObserverTask.model_validate(task_v2)
raise NotFoundError(f"TaskV2 {task_v2_id} not found")
async def create_workflow_run_block(
self,