from datetime import datetime
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, cast

from loguru import logger
from pydantic import BaseModel, ValidationError, field_validator

from open_notebook.database.repository import (
    repo_create,
    repo_delete,
    repo_query,
    repo_relate,
    repo_update,
)
from open_notebook.exceptions import (
    DatabaseOperationError,
    InvalidInputError,
    NotFoundError,
)

T = TypeVar("T", bound="ObjectModel")


class ObjectModel(BaseModel):
    """Base class for models persisted through the repository functions.

    Subclasses set ``table_name`` to the table they are stored in. Record ids
    are treated as ``"<table>:<key>"`` strings: the part before the first
    colon selects the table (see ``get``).
    """

    id: Optional[str] = None
    # Table backing this model; empty on the base class itself.
    table_name: ClassVar[str] = ""
    created: Optional[datetime] = None
    updated: Optional[datetime] = None

    @classmethod
    def get_all(cls: Type[T], order_by=None) -> List[T]:
        """Return every record of ``cls.table_name`` as instances of ``cls``.

        Args:
            order_by: Optional ordering expression appended verbatim after
                ``ORDER BY``.

        Returns:
            The rows that could be turned into ``cls`` instances. Rows that
            fail construction are logged at critical level and skipped.

        Raises:
            DatabaseOperationError: On any failure, including being called on
                a class that has no ``table_name``.
        """
        try:
            if not cls.table_name:
                # Reached when called directly on ObjectModel. Raised inside
                # the ``try`` on purpose so it surfaces as
                # DatabaseOperationError, as it always has.
                raise InvalidInputError(
                    "get_all() must be called from a specific model class"
                )

            # NOTE(review): ``order_by`` is interpolated into the query text
            # as-is; callers must not pass untrusted input here.
            order = f" ORDER BY {order_by}" if order_by else ""
            rows = repo_query(f"SELECT * FROM {cls.table_name} {order}")

            objects = []
            for row in rows:
                # Best effort: one malformed row must not hide the others.
                try:
                    objects.append(cls(**row))
                except Exception as e:
                    logger.critical(f"Error creating object: {str(e)}")
            return objects
        except Exception as e:
            logger.error(f"Error fetching all {cls.table_name}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    @classmethod
    def get(cls: Type[T], id: str) -> T:
        """Fetch a single record by id and build the matching model.

        The table is taken from the id (everything before the first colon).
        When it matches ``cls.table_name`` the result is a ``cls`` instance;
        otherwise the ObjectModel subclass registered for that table is used.

        Args:
            id: Record id, normally ``"<table>:<key>"``.

        Raises:
            InvalidInputError: If ``id`` is empty.
            NotFoundError: If the record does not exist, or for any other
                failure during lookup (including an unknown table).
        """
        if not id:
            raise InvalidInputError("ID cannot be empty")

        try:
            # Table name is everything before the first colon.
            table_name = id.split(":")[0] if ":" in id else id

            if cls.table_name and cls.table_name == table_name:
                target_class: Type[T] = cls
            else:
                # Called on the base class or on a different model: resolve
                # the concrete class from the table name.
                found_class = cls._get_class_by_table_name(table_name)
                if not found_class:
                    raise InvalidInputError(f"No class found for table {table_name}")
                target_class = cast(Type[T], found_class)

            # NOTE(review): ``id`` is interpolated into the query text as-is;
            # callers must not pass untrusted input here.
            result = repo_query(f"SELECT * FROM {id}")
            if not result:
                raise NotFoundError(f"{table_name} with id {id} not found")
            return target_class(**result[0])
        except Exception as e:
            # Every failure above is reported as NotFoundError, which is the
            # type callers of this method have always received.
            logger.error(f"Error fetching object with id {id}: {str(e)}")
            logger.exception(e)
            raise NotFoundError(f"Object with id {id} not found - {str(e)}") from e

    @classmethod
    def _get_class_by_table_name(cls, table_name: str) -> Optional[Type["ObjectModel"]]:
        """Find the appropriate subclass based on table_name.

        Searches all direct and indirect subclasses of ``ObjectModel`` in
        depth-first order and returns the first whose ``table_name`` matches,
        or ``None`` when there is no match.
        """

        def get_all_subclasses(c: Type["ObjectModel"]) -> List[Type["ObjectModel"]]:
            all_subclasses: List[Type["ObjectModel"]] = []
            for subclass in c.__subclasses__():
                all_subclasses.append(subclass)
                all_subclasses.extend(get_all_subclasses(subclass))
            return all_subclasses

        for subclass in get_all_subclasses(ObjectModel):
            if hasattr(subclass, "table_name") and subclass.table_name == table_name:
                return subclass
        return None

    def needs_embedding(self) -> bool:
        """Return whether ``save`` should compute an embedding (default: no)."""
        return False

    def get_embedding_content(self) -> Optional[str]:
        """Return the text to embed on save, or ``None`` for nothing."""
        return None

    def save(self) -> None:
        """Validate and persist this object, then refresh it from the result.

        Creates a new record when ``id`` is ``None`` and updates the existing
        one otherwise. ``updated`` is always stamped with the current local
        time; ``created`` is stamped on creation and re-sent on update. When
        ``needs_embedding()`` is true and ``get_embedding_content()`` returns
        text, an ``embedding`` field is added to the payload.

        Raises:
            ValidationError: If the current field values fail strict
                validation.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        # Imported here rather than at module level, as in the original
        # (presumably to avoid an import cycle — TODO confirm).
        from open_notebook.domain.models import model_manager
        from open_notebook.models import EmbeddingModel

        try:
            self.model_validate(self.model_dump(), strict=True)
            data = self._prepare_save_data()
            data["updated"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            if self.needs_embedding():
                embedding_content = self.get_embedding_content()
                if embedding_content:
                    EMBEDDING_MODEL: EmbeddingModel = model_manager.embedding_model
                    data["embedding"] = EMBEDDING_MODEL.embed(embedding_content)

            if self.id is None:
                data["created"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                repo_result = repo_create(self.__class__.table_name, data)
            else:
                data["created"] = (
                    self.created.strftime("%Y-%m-%d %H:%M:%S")
                    if isinstance(self.created, datetime)
                    else self.created
                )
                logger.debug(f"Updating record with id {self.id}")
                repo_result = repo_update(self.id, data)

            # Update the current instance with the result.
            # Assumes the repository returns a non-empty list of dicts.
            for key, value in repo_result[0].items():
                if hasattr(self, key):
                    current = getattr(self, key)
                    if isinstance(current, BaseModel):
                        # Rebuild nested models from the returned mapping so
                        # the attribute keeps its model type.
                        setattr(self, key, type(current)(**value))
                    else:
                        setattr(self, key, value)

        except ValidationError as e:
            logger.error(f"Validation failed: {e}")
            raise
        except Exception as e:
            # The original exception propagates unchanged. A second
            # ``except Exception`` clause that wrapped it in
            # DatabaseOperationError used to follow this one but could never
            # run, so it was removed without changing behaviour.
            logger.error(f"Error saving record: {e}")
            raise

    def _prepare_save_data(self) -> Dict[str, Any]:
        """Return the model's fields as a dict, omitting ``None`` values."""
        data = self.model_dump()
        return {key: value for key, value in data.items() if value is not None}

    def delete(self) -> bool:
        """Delete this object's record.

        Returns:
            Whatever ``repo_delete`` returns for this id.

        Raises:
            InvalidInputError: If the object has no ``id``.
            DatabaseOperationError: If the repository call fails.
        """
        if self.id is None:
            raise InvalidInputError("Cannot delete object without an ID")
        try:
            logger.debug(f"Deleting record with id {self.id}")
            return repo_delete(self.id)
        except Exception as e:
            logger.error(
                f"Error deleting {self.__class__.table_name} with id {self.id}: {str(e)}"
            )
            raise DatabaseOperationError(
                f"Failed to delete {self.__class__.table_name}"
            ) from e

    def relate(self, relationship: str, target_id: str) -> Any:
        """Create a ``relationship`` from this object to ``target_id``.

        Returns:
            Whatever ``repo_relate`` returns.

        Raises:
            InvalidInputError: If ``relationship``, ``target_id`` or this
                object's ``id`` is missing.
            DatabaseOperationError: If the repository call fails.
        """
        if not relationship or not target_id or not self.id:
            raise InvalidInputError("Relationship and target ID must be provided")
        try:
            return repo_relate(self.id, relationship, target_id)
        except Exception as e:
            logger.error(f"Error creating relationship: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    @field_validator("created", "updated", mode="before")
    @classmethod
    def parse_datetime(cls, value):
        """Parse ISO-8601 strings (a ``Z`` suffix is accepted) into datetimes.

        Non-string values are passed through untouched.
        """
        if isinstance(value, str):
            # ``fromisoformat`` does not accept a trailing "Z" on every Python
            # version, so rewrite it as an explicit UTC offset first.
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        return value


class RecordModel(BaseModel):
    """Model bound to one fixed record, identified by ``record_id``.

    Constructing an instance immediately loads the record's stored values,
    creating an empty record first if none exists.
    """

    record_id: ClassVar[str]

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.load()

    def load(self):
        """Refresh this instance from the stored record and return ``self``.

        When the query returns nothing, an empty record is created and no
        attributes are changed. Only keys that already exist as attributes on
        the instance are copied.
        """
        result = repo_query(f"SELECT * FROM {self.record_id};")
        if result:
            result = result[0]
        else:
            repo_create(self.record_id, {})
            result = {}

        for key, value in result.items():
            if hasattr(self, key):
                setattr(self, key, value)
        return self

    def update(self, data):
        """Write ``data`` to the record, then reload and return ``self``."""
        repo_update(self.record_id, data)
        return self.load()