mirror of
https://github.com/Skyvern-AI/skyvern.git
synced 2026-04-28 03:30:10 +00:00
112 lines
3.9 KiB
Python
112 lines
3.9 KiB
Python
import asyncio
|
|
import contextvars
|
|
from contextlib import asynccontextmanager
|
|
from functools import wraps
|
|
from typing import Any, AsyncContextManager, AsyncIterator, Callable
|
|
|
|
import structlog
|
|
from sqlalchemy.exc import SQLAlchemyError
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
|
|
|
# Module-level structlog logger used by the retry decorator and the session factory in this module.
LOG = structlog.get_logger()
|
|
|
|
|
|
def read_retry(retries: int = 3) -> Callable:
    """Build a decorator that retries an async database method on transient failures.

    The decorated coroutine must receive the database client as its first
    positional argument; that object's ``is_retryable_error`` decides whether a
    caught ``SQLAlchemyError`` is retried or re-raised immediately.

    Between attempts the wrapper sleeps ``0.2 * 2**attempt`` seconds
    (0.2, 0.4, 0.8, ...), where ``attempt`` is zero-based.

    Args:
        retries: Maximum number of attempts (default: 3). Values below 1 are
            treated as 1 so the wrapped coroutine is always invoked at least once.

    Returns:
        A decorator that wraps an async callable with the retry logic.

    The wrapper re-raises:
        * the original ``SQLAlchemyError`` when it is not retryable, or when the
          last attempt has failed;
        * any other ``Exception`` unchanged, after logging it.
    """
    # With retries <= 0 the loop below would never execute, so the wrapped
    # coroutine would not run at all and the caller would only see the
    # "Retry logic error" RuntimeError. Clamp to a single attempt instead.
    max_attempts = max(1, retries)

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(fn)
        async def wrapper(
            base_db: "BaseAlchemyDB",
            *args: Any,
            **kwargs: Any,
        ) -> Any:
            for attempt in range(max_attempts):
                try:
                    return await fn(base_db, *args, **kwargs)
                except SQLAlchemyError as e:
                    # Non-retryable errors are surfaced immediately.
                    if not base_db.is_retryable_error(e):
                        LOG.error("SQLAlchemyError", exc_info=True, attempt=attempt)
                        raise
                    # Retryable, but this was the final attempt.
                    if attempt >= max_attempts - 1:
                        LOG.error("SQLAlchemyError after all retries", exc_info=True, attempt=attempt)
                        raise

                    # Exponential backoff before the next attempt.
                    backoff_time = 0.2 * (2**attempt)
                    LOG.warning(
                        "SQLAlchemyError retrying",
                        attempt=attempt,
                        backoff_time=backoff_time,
                        exc_info=True,
                    )
                    await asyncio.sleep(backoff_time)

                except Exception:
                    LOG.error("UnexpectedError", exc_info=True)
                    raise

            # Every iteration either returns or raises, so this line is not
            # expected to be reached; it is kept as a defensive guard.
            raise RuntimeError(f"Retry logic error in {fn.__name__}")

        return wrapper

    return decorator
|
|
|
|
|
|
class BaseAlchemyDB:
    """Foundation for database clients: holds the async engine and a session factory.

    Attributes set by the constructor:
        engine: the ``AsyncEngine`` passed in.
        Session: a ``_SessionFactory`` bound to that engine; calling it yields an
            async context manager that provides a session.
    """

    def __init__(self, db_engine: AsyncEngine) -> None:
        """Store the engine and build the session factory bound to it."""
        self.engine = db_engine
        # Build the underlying SQLAlchemy sessionmaker first, then wrap it so the
        # factory can consult this client (e.g. for is_retryable_error).
        bound_sessionmaker = async_sessionmaker(bind=db_engine)
        self.Session = _SessionFactory(self, bound_sessionmaker)

    def is_retryable_error(self, error: SQLAlchemyError) -> bool:
        """Report whether ``error`` is transient and worth retrying.

        The base implementation treats no error as retryable; subclasses
        override this to recognise backend-specific transient failures.
        """
        return False
|
|
|
|
|
|
class _SessionFactory:
    """Callable wrapper around an ``async_sessionmaker`` that reuses the active session.

    Calling the factory returns an async context manager. If a session opened by
    this factory is already active in the current context, that same session is
    yielded and left open; otherwise a new session is created, published through
    a ``ContextVar``, and closed when the outermost block exits.

    Attributes not defined here are forwarded to the wrapped sessionmaker.

    NOTE(review): the active session is tracked in a ContextVar, and asyncio
    tasks start with a copy of their creator's context. Tasks spawned inside an
    open session block will presumably see and share that session -- confirm
    callers do not rely on concurrent use of a single session.
    """

    def __init__(self, db: BaseAlchemyDB, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
        """Remember the owning client and sessionmaker, and create the tracking ContextVar."""
        self._db = db
        self._sessionmaker = sessionmaker
        # Holds the session currently open in this context, or None.
        self._session_ctx: contextvars.ContextVar[AsyncSession | None] = contextvars.ContextVar(
            "skyvern_db_session",
            default=None,
        )

    def __call__(self) -> AsyncContextManager[AsyncSession]:
        """Return an async context manager yielding the active or a new session."""
        return self._session()

    def __getattr__(self, name: str) -> Any:
        """Forward unknown attribute lookups to the wrapped sessionmaker."""
        # __getattr__ only runs after normal lookup has failed. If the missing
        # name is `_sessionmaker` itself (instance built without __init__, e.g.
        # by copy/pickle), delegating through self._sessionmaker would re-enter
        # this method forever and end in RecursionError.
        if name == "_sessionmaker":
            raise AttributeError(name)
        return getattr(self._sessionmaker, name)

    @asynccontextmanager
    async def _session(self) -> AsyncIterator[AsyncSession]:
        """Yield the context's active session, creating and later closing one if needed."""
        existing_session = self._session_ctx.get()
        if existing_session is not None:
            # Nested use: hand out the outer session; the outermost block owns
            # its lifetime, so nothing is closed here.
            yield existing_session
            return

        session = self._sessionmaker()
        token = self._session_ctx.set(session)
        try:
            yield session
        finally:
            # Unpublish the session before closing it so nothing can pick it up
            # while it is being torn down.
            self._session_ctx.reset(token)
            try:
                await session.close()
            except SQLAlchemyError as e:
                # Handle transient errors during session cleanup gracefully.
                # This can happen on replicas when the connection is terminated due to
                # WAL replay conflicts. The body of the block has already finished by
                # the time we get here, so a transient close failure is logged and
                # suppressed; anything else is re-raised.
                if self._db.is_retryable_error(e):
                    LOG.warning(
                        "Transient error during session close (suppressed)",
                        error=str(e),
                    )
                else:
                    raise
|