mirror of
https://github.com/lfnovo/open-notebook.git
synced 2026-04-30 20:39:55 +00:00
refactor objectmodel
This commit is contained in:
parent
f140a5e228
commit
c297dcb809
8 changed files with 186 additions and 68 deletions
|
|
@ -1,8 +1,22 @@
|
|||
from datetime import datetime
|
||||
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, cast
|
||||
from typing import (
|
||||
Any,
|
||||
ClassVar,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, ValidationError, field_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ValidationError,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from open_notebook.database.repository import (
|
||||
repo_create,
|
||||
|
|
@ -10,6 +24,7 @@ from open_notebook.database.repository import (
|
|||
repo_query,
|
||||
repo_relate,
|
||||
repo_update,
|
||||
repo_upsert,
|
||||
)
|
||||
from open_notebook.exceptions import (
|
||||
DatabaseOperationError,
|
||||
|
|
@ -204,24 +219,92 @@ class ObjectModel(BaseModel):
|
|||
|
||||
class RecordModel(BaseModel):
|
||||
record_id: ClassVar[str]
|
||||
auto_save: ClassVar[bool] = (
|
||||
False # Default to False, can be overridden in subclasses
|
||||
)
|
||||
_instances: ClassVar[Dict[str, "RecordModel"]] = {} # Store instances by record_id
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
arbitrary_types_allowed = True
|
||||
extra = "allow"
|
||||
from_attributes = True
|
||||
defer_build = True
|
||||
|
||||
def __new__(cls, **kwargs):
|
||||
# If an instance already exists for this record_id, return it
|
||||
if cls.record_id in cls._instances:
|
||||
instance = cls._instances[cls.record_id]
|
||||
# Update instance with any new kwargs if provided
|
||||
if kwargs:
|
||||
for key, value in kwargs.items():
|
||||
setattr(instance, key, value)
|
||||
return instance
|
||||
|
||||
# If no instance exists, create a new one
|
||||
instance = super().__new__(cls)
|
||||
cls._instances[cls.record_id] = instance
|
||||
return instance
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.load()
|
||||
# Only initialize if this is a new instance
|
||||
if not hasattr(self, "_initialized"):
|
||||
object.__setattr__(self, "__dict__", {})
|
||||
# Load data from DB first
|
||||
result = repo_query(f"SELECT * FROM {self.record_id};")
|
||||
if result:
|
||||
db_data = result[0]
|
||||
else:
|
||||
# Initialize empty object with None for Optional fields
|
||||
db_data = {
|
||||
field_name: None
|
||||
for field_name, field_info in self.model_fields.items()
|
||||
if not str(field_info.annotation).startswith("typing.ClassVar")
|
||||
}
|
||||
|
||||
# Initialize with DB data and any overrides
|
||||
super().__init__(**{**db_data, **kwargs})
|
||||
object.__setattr__(self, "_initialized", True)
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "RecordModel":
|
||||
"""Get or create the singleton instance"""
|
||||
return cls()
|
||||
|
||||
@model_validator(mode="after")
|
||||
def auto_save_validator(self):
|
||||
if self.__class__.auto_save:
|
||||
self.update()
|
||||
return self
|
||||
|
||||
def update(self):
|
||||
# Get all non-ClassVar fields and their values
|
||||
data = {
|
||||
field_name: getattr(self, field_name)
|
||||
for field_name, field_info in self.model_fields.items()
|
||||
if not str(field_info.annotation).startswith("typing.ClassVar")
|
||||
}
|
||||
|
||||
repo_upsert(self.record_id, data)
|
||||
|
||||
def load(self):
|
||||
result = repo_query(f"SELECT * FROM {self.record_id};")
|
||||
if result:
|
||||
result = result[0]
|
||||
else:
|
||||
repo_create(self.record_id, {})
|
||||
result = {}
|
||||
for key, value in result.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
for key, value in result[0].items():
|
||||
if hasattr(self, key):
|
||||
object.__setattr__(
|
||||
self, key, value
|
||||
) # Use object.__setattr__ to avoid triggering validation again
|
||||
|
||||
return self
|
||||
|
||||
def update(self, data):
|
||||
repo_update(self.record_id, data)
|
||||
return self.load()
|
||||
@classmethod
|
||||
def clear_instance(cls):
|
||||
"""Clear the singleton instance (useful for testing)"""
|
||||
if cls.record_id in cls._instances:
|
||||
del cls._instances[cls.record_id]
|
||||
|
||||
def patch(self, model_dict: dict):
|
||||
"""Update model attributes from dictionary and save"""
|
||||
for key, value in model_dict.items():
|
||||
setattr(self, key, value)
|
||||
self.update()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue