refactor objectmodel

This commit is contained in:
LUIS NOVO 2024-11-19 19:03:32 -03:00
parent f140a5e228
commit c297dcb809
8 changed files with 186 additions and 68 deletions

View file

@ -1,8 +1,22 @@
from datetime import datetime
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, cast
from typing import (
Any,
ClassVar,
Dict,
List,
Optional,
Type,
TypeVar,
cast,
)
from loguru import logger
from pydantic import BaseModel, ValidationError, field_validator
from pydantic import (
BaseModel,
ValidationError,
field_validator,
model_validator,
)
from open_notebook.database.repository import (
repo_create,
@ -10,6 +24,7 @@ from open_notebook.database.repository import (
repo_query,
repo_relate,
repo_update,
repo_upsert,
)
from open_notebook.exceptions import (
DatabaseOperationError,
@ -204,24 +219,92 @@ class ObjectModel(BaseModel):
class RecordModel(BaseModel):
record_id: ClassVar[str]
auto_save: ClassVar[bool] = (
False # Default to False, can be overridden in subclasses
)
_instances: ClassVar[Dict[str, "RecordModel"]] = {} # Store instances by record_id
class Config:
validate_assignment = True
arbitrary_types_allowed = True
extra = "allow"
from_attributes = True
defer_build = True
def __new__(cls, **kwargs):
# If an instance already exists for this record_id, return it
if cls.record_id in cls._instances:
instance = cls._instances[cls.record_id]
# Update instance with any new kwargs if provided
if kwargs:
for key, value in kwargs.items():
setattr(instance, key, value)
return instance
# If no instance exists, create a new one
instance = super().__new__(cls)
cls._instances[cls.record_id] = instance
return instance
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.load()
# Only initialize if this is a new instance
if not hasattr(self, "_initialized"):
object.__setattr__(self, "__dict__", {})
# Load data from DB first
result = repo_query(f"SELECT * FROM {self.record_id};")
if result:
db_data = result[0]
else:
# Initialize empty object with None for Optional fields
db_data = {
field_name: None
for field_name, field_info in self.model_fields.items()
if not str(field_info.annotation).startswith("typing.ClassVar")
}
# Initialize with DB data and any overrides
super().__init__(**{**db_data, **kwargs})
object.__setattr__(self, "_initialized", True)
@classmethod
def get_instance(cls) -> "RecordModel":
"""Get or create the singleton instance"""
return cls()
@model_validator(mode="after")
def auto_save_validator(self):
if self.__class__.auto_save:
self.update()
return self
def update(self):
# Get all non-ClassVar fields and their values
data = {
field_name: getattr(self, field_name)
for field_name, field_info in self.model_fields.items()
if not str(field_info.annotation).startswith("typing.ClassVar")
}
repo_upsert(self.record_id, data)
def load(self):
result = repo_query(f"SELECT * FROM {self.record_id};")
if result:
result = result[0]
else:
repo_create(self.record_id, {})
result = {}
for key, value in result.items():
if hasattr(self, key):
setattr(self, key, value)
for key, value in result[0].items():
if hasattr(self, key):
object.__setattr__(
self, key, value
) # Use object.__setattr__ to avoid triggering validation again
return self
def update(self, data):
repo_update(self.record_id, data)
return self.load()
@classmethod
def clear_instance(cls):
"""Clear the singleton instance (useful for testing)"""
if cls.record_id in cls._instances:
del cls._instances[cls.record_id]
def patch(self, model_dict: dict):
"""Update model attributes from dictionary and save"""
for key, value in model_dict.items():
setattr(self, key, value)
self.update()