add functionality to cache task_run (#1755)

This commit is contained in:
Shuchang Zheng 2025-02-11 14:47:41 +08:00 committed by GitHub
parent 8c43e6b70e
commit defd761e58
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 127 additions and 18 deletions

View file

@ -2672,3 +2672,30 @@ class AgentDB:
await session.commit()
await session.refresh(task_run)
return TaskRun.model_validate(task_run)
async def cache_task_run(self, run_id: str, organization_id: str | None = None) -> TaskRun:
async with self.Session() as session:
task_run = await session.scalars(
select(TaskRunModel).filter_by(organization_id=organization_id).filter_by(run_id=run_id)
).first()
if task_run:
task_run.cached = True
await session.commit()
await session.refresh(task_run)
return TaskRun.model_validate(task_run)
raise NotFoundError(f"TaskRun {run_id} not found")
async def get_cached_task_run(
self, task_run_type: TaskRunType, url_hash: str | None = None, organization_id: str | None = None
) -> TaskRun | None:
async with self.Session() as session:
query = select(TaskRunModel)
if task_run_type:
query = query.filter_by(task_run_type=task_run_type)
if url_hash:
query = query.filter_by(url_hash=url_hash)
if organization_id:
query = query.filter_by(organization_id=organization_id)
query = query.filter_by(cached=True).order_by(TaskRunModel.created_at.desc())
task_run = await session.scalars(query).first()
return TaskRun.model_validate(task_run) if task_run else None