import os
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, TypeVar, Union

from loguru import logger
from surrealdb import AsyncSurreal, RecordID  # type: ignore

T = TypeVar("T", Dict[str, Any], List[Dict[str, Any]])


def get_database_url() -> str:
    """Return the SurrealDB connection URL.

    ``SURREAL_URL`` takes precedence when set. Otherwise the URL is assembled
    from the legacy ``SURREAL_ADDRESS`` / ``SURREAL_PORT`` variables
    (defaulting to ``localhost`` and ``8000``) as a WebSocket RPC endpoint.
    """
    surreal_url = os.getenv("SURREAL_URL")
    if surreal_url:
        return surreal_url

    # Backward-compatible fallback. The port belongs to the authority
    # (``host:port``), not to the path: ``ws://host:port/rpc``.
    address = os.getenv("SURREAL_ADDRESS", "localhost")
    port = os.getenv("SURREAL_PORT", "8000")
    return f"ws://{address}:{port}/rpc"


def get_database_password() -> Optional[str]:
    """Return the database password.

    ``SURREAL_PASSWORD`` is preferred; the legacy ``SURREAL_PASS`` is used
    when the former is unset or empty. Returns ``None`` if neither is set.
    """
    return os.getenv("SURREAL_PASSWORD") or os.getenv("SURREAL_PASS")


def parse_record_ids(obj: Any) -> Any:
    """Recursively convert every ``RecordID`` inside dicts/lists to ``str``.

    Other values (including tuples and sets) are returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: parse_record_ids(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [parse_record_ids(item) for item in obj]
    elif isinstance(obj, RecordID):
        return str(obj)
    return obj


def ensure_record_id(value: Union[str, RecordID]) -> RecordID:
    """Return ``value`` as a ``RecordID``, parsing it if it is a string."""
    if isinstance(value, RecordID):
        return value
    return RecordID.parse(value)


@asynccontextmanager
async def db_connection():
    """Yield an authenticated SurrealDB connection and close it on exit.

    Credentials come from ``SURREAL_USER`` and ``get_database_password()``.
    The namespace and database come from ``SURREAL_NAMESPACE`` and
    ``SURREAL_DATABASE``.
    """
    db = AsyncSurreal(get_database_url())
    try:
        # Sign-in and scope selection live inside ``try`` so the connection
        # is still closed when authentication or ``use`` fails.
        await db.signin(
            {
                "username": os.environ.get("SURREAL_USER"),
                "password": get_database_password(),
            }
        )
        await db.use(
            os.environ.get("SURREAL_NAMESPACE"), os.environ.get("SURREAL_DATABASE")
        )
        yield db
    finally:
        await db.close()


async def repo_query(
    query_str: str, vars: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Execute a SurrealQL query and return its results.

    ``RecordID`` values in the result are converted to strings.

    Raises:
        RuntimeError: if the server answers with an error string instead of
            results (logged without a stack trace).
        Exception: any other failure during the query is logged with a
            traceback and re-raised unchanged.
    """
    async with db_connection() as connection:
        try:
            result = parse_record_ids(await connection.query(query_str, vars))
            # The client hands back server-side errors as a plain string.
            if isinstance(result, str):
                raise RuntimeError(result)
            return result
        except RuntimeError as e:
            # RuntimeError is raised for retriable transaction conflicts -
            # log without stack trace.
            logger.error(str(e))
            raise
        except Exception as e:
            logger.exception(e)
            raise


async def repo_create(table: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Insert a new record into ``table`` with fresh timestamps.

    Any ``id`` key in ``data`` is dropped so the database assigns one.
    ``created`` and ``updated`` are both set to the same current UTC time.
    The caller's dict is not modified.

    Raises:
        RuntimeError: on failure. A ``RuntimeError`` from the driver is
            logged and re-raised as-is; any other exception is logged with
            a traceback and wrapped, keeping the original as ``__cause__``.
    """
    payload = dict(data)
    payload.pop("id", None)
    now = datetime.now(timezone.utc)
    payload["created"] = now
    payload["updated"] = now
    try:
        async with db_connection() as connection:
            return parse_record_ids(await connection.insert(table, payload))
    except RuntimeError as e:
        logger.error(str(e))
        raise
    except Exception as e:
        logger.exception(e)
        raise RuntimeError("Failed to create record") from e


async def repo_relate(
    source: str, relationship: str, target: str, data: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Create a ``source -> relationship -> target`` edge with optional data.

    ``source``, ``relationship`` and ``target`` are interpolated directly
    into the query text, so they must be trusted identifiers/record ids.
    Only ``data`` is passed as a bound parameter.
    """
    if data is None:
        data = {}
    query = f"RELATE {source}->{relationship}->{target} CONTENT $data;"
    return await repo_query(query, {"data": data})


async def repo_upsert(
    table: str, id: Optional[str], data: Dict[str, Any], add_timestamp: bool = False
) -> List[Dict[str, Any]]:
    """Create or update a record by merging ``data`` into it.

    The UPSERT targets ``id`` when it is truthy and ``table`` otherwise.
    Any ``id`` key in ``data`` is dropped. When ``add_timestamp`` is true,
    ``updated`` is set to the current UTC time. The caller's dict is not
    modified.
    """
    payload = dict(data)
    payload.pop("id", None)
    if add_timestamp:
        payload["updated"] = datetime.now(timezone.utc)
    query = f"UPSERT {id if id else table} MERGE $data;"
    return await repo_query(query, {"data": payload})


async def repo_update(
    table: str, id: str, data: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Merge ``data`` into an existing record and bump its ``updated`` time.

    ``id`` may be a ``RecordID``, a full ``"table:key"`` string, or a bare
    key, which is prefixed with ``table:``. A string ``created`` value is
    parsed with ``datetime.fromisoformat``. The caller's dict is not
    modified.

    Raises:
        RuntimeError: wrapping any failure, with the original exception
            kept as ``__cause__``.
    """
    try:
        # Keep ids that already carry the table prefix as they are.
        if isinstance(id, RecordID) or id.startswith(f"{table}:"):
            record_id = id
        else:
            record_id = f"{table}:{id}"

        payload = dict(data)
        payload.pop("id", None)
        if isinstance(payload.get("created"), str):
            payload["created"] = datetime.fromisoformat(payload["created"])
        payload["updated"] = datetime.now(timezone.utc)

        query = f"UPDATE {record_id} MERGE $data;"
        # repo_query already converts RecordIDs to strings.
        return await repo_query(query, {"data": payload})
    except Exception as e:
        raise RuntimeError(f"Failed to update record: {str(e)}") from e


async def repo_delete(record_id: Union[str, RecordID]):
    """Delete the record identified by ``record_id``.

    A string id is parsed into a ``RecordID`` first.

    Raises:
        RuntimeError: wrapping any failure (logged with a traceback), with
            the original exception kept as ``__cause__``.
    """
    try:
        async with db_connection() as connection:
            return await connection.delete(ensure_record_id(record_id))
    except Exception as e:
        logger.exception(e)
        raise RuntimeError(f"Failed to delete record: {str(e)}") from e


async def repo_insert(
    table: str, data: List[Dict[str, Any]], ignore_duplicates: bool = False
) -> List[Dict[str, Any]]:
    """Insert one or more records into ``table``.

    When ``ignore_duplicates`` is true and the error message contains
    ``"already contains"``, the failure is swallowed and ``[]`` is returned.

    Raises:
        RuntimeError: wrapping any other failure (logged with a traceback),
            with the original exception kept as ``__cause__``.
    """
    try:
        async with db_connection() as connection:
            return parse_record_ids(await connection.insert(table, data))
    except Exception as e:
        if ignore_duplicates and "already contains" in str(e):
            return []
        logger.exception(e)
        raise RuntimeError("Failed to create record") from e