mirror of
https://github.com/eigent-ai/eigent.git
synced 2026-05-23 21:06:50 +00:00
Some checks are pending
Pre-commit / pre-commit (push) Waiting to run
CodeQL Advanced / Analyze (actions) (push) Waiting to run
CodeQL Advanced / Analyze (javascript-typescript) (push) Waiting to run
CodeQL Advanced / Analyze (python) (push) Waiting to run
Test / Run Python Tests (push) Waiting to run
193 lines
6.1 KiB
Python
193 lines
6.1 KiB
Python
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. =========
|
||
|
||
import logging
|
||
from datetime import datetime
|
||
from typing import Any
|
||
|
||
from convert_case import snake_case
|
||
from fastapi_babel import _
|
||
from sqlalchemy import delete
|
||
from sqlalchemy.orm import declared_attr
|
||
from sqlalchemy.sql.base import ExecutableOption
|
||
from sqlalchemy.sql.expression import ColumnExpressionArgument
|
||
from sqlmodel import (
|
||
TIMESTAMP,
|
||
Field,
|
||
Session,
|
||
SQLModel,
|
||
col,
|
||
func,
|
||
select,
|
||
text,
|
||
)
|
||
|
||
from app.core import code
|
||
from app.core.database import engine
|
||
from app.shared.exception import UserException
|
||
|
||
logger = logging.getLogger("abstract_model")
|
||
|
||
|
||
class AbstractModel(SQLModel):
|
||
@declared_attr # type: ignore
|
||
def __tablename__(cls) -> str:
|
||
return snake_case(cls.__name__)
|
||
|
||
@classmethod
|
||
def by(
|
||
cls,
|
||
*whereclause: ColumnExpressionArgument[bool] | bool,
|
||
order_by: Any | None = None,
|
||
limit: int | None = None,
|
||
offset: int | None = None,
|
||
options: ExecutableOption | list[ExecutableOption] | None = None,
|
||
s: Session,
|
||
):
|
||
logger.debug(
|
||
"Executing query by conditions",
|
||
extra={
|
||
"model_class": cls.__name__,
|
||
"has_order_by": order_by is not None,
|
||
"limit": limit,
|
||
"offset": offset,
|
||
"has_options": options is not None,
|
||
},
|
||
)
|
||
stmt = select(cls).where(*whereclause)
|
||
if order_by is not None:
|
||
stmt = stmt.order_by(order_by)
|
||
if limit is not None:
|
||
stmt = stmt.limit(limit)
|
||
if offset is not None:
|
||
stmt = stmt.offset(offset)
|
||
if options is not None:
|
||
stmt = stmt.options(*(options if isinstance(options, list) else [options]))
|
||
return s.exec(stmt, execution_options={"prebuffer_rows": True})
|
||
|
||
@classmethod
|
||
def exists(
|
||
cls,
|
||
*whereclause: ColumnExpressionArgument[bool] | bool,
|
||
s: Session,
|
||
) -> bool:
|
||
logger.debug("Checking if record exists", extra={"model_class": cls.__name__})
|
||
res = s.exec(select(func.count("*")).where(*whereclause)).first()
|
||
result = res is not None and res > 0
|
||
logger.debug(
|
||
"Record existence check result", extra={"model_class": cls.__name__, "exists": result, "count": res}
|
||
)
|
||
return result
|
||
|
||
@classmethod
|
||
def count(
|
||
cls,
|
||
*whereclause: ColumnExpressionArgument[bool] | bool,
|
||
s: Session,
|
||
) -> int:
|
||
res = s.exec(select(func.count("*")).where(*whereclause)).first()
|
||
return res if res is not None else 0
|
||
|
||
@classmethod
|
||
def exists_must(
|
||
cls,
|
||
*whereclause: ColumnExpressionArgument[bool] | bool,
|
||
s: Session,
|
||
):
|
||
if not cls.exists(*whereclause, s=s):
|
||
raise UserException(code.not_found, _("There is no data that meets the conditions"))
|
||
|
||
@classmethod
|
||
def delete_by(
|
||
cls,
|
||
*whereclause: ColumnExpressionArgument[bool],
|
||
s: Session,
|
||
):
|
||
logger.info("Deleting records by conditions", extra={"model_class": cls.__name__})
|
||
stmt = delete(cls).where(*whereclause)
|
||
result = s.connection().execute(stmt)
|
||
s.commit()
|
||
logger.info("Records deleted", extra={"model_class": cls.__name__, "rows_affected": result.rowcount})
|
||
|
||
def save(self, s: Session | None = None):
|
||
model_id = getattr(self, "id", None)
|
||
is_new = model_id is None
|
||
logger.info(
|
||
"Saving model",
|
||
extra={"model_class": self.__class__.__name__, "model_id": model_id, "is_new_record": is_new},
|
||
)
|
||
|
||
if s is None:
|
||
with Session(engine, expire_on_commit=False) as s:
|
||
s.add(self)
|
||
s.commit()
|
||
else:
|
||
s.add(self)
|
||
s.commit()
|
||
|
||
logger.info(
|
||
"Model saved successfully",
|
||
extra={
|
||
"model_class": self.__class__.__name__,
|
||
"model_id": getattr(self, "id", None),
|
||
"was_new_record": is_new,
|
||
},
|
||
)
|
||
|
||
def delete(self, s: Session):
|
||
model_id = getattr(self, "id", None)
|
||
is_soft_delete = isinstance(self, DefaultTimes)
|
||
|
||
logger.info(
|
||
"Deleting model",
|
||
extra={"model_class": self.__class__.__name__, "model_id": model_id, "is_soft_delete": is_soft_delete},
|
||
)
|
||
|
||
if isinstance(self, DefaultTimes):
|
||
self.deleted_at = datetime.now()
|
||
self.save(s)
|
||
else:
|
||
s.delete(self)
|
||
s.commit()
|
||
|
||
logger.info(
|
||
"Model deleted successfully",
|
||
extra={"model_class": self.__class__.__name__, "model_id": model_id, "was_soft_delete": is_soft_delete},
|
||
)
|
||
|
||
def update_fields(self, update_dict: dict):
|
||
for k, v in update_dict.items():
|
||
setattr(self, k, v)
|
||
|
||
|
||
class DefaultTimes:
|
||
deleted_at: datetime | None = Field(default=None)
|
||
created_at: datetime | None = Field(
|
||
# 兼容mysql,如果只有数据库的保存的话,保存后,created_at为None,无法立即调用
|
||
default_factory=datetime.now,
|
||
sa_type=TIMESTAMP,
|
||
sa_column_kwargs={"server_default": text("CURRENT_TIMESTAMP")},
|
||
)
|
||
updated_at: datetime | None = Field(
|
||
default_factory=datetime.now,
|
||
sa_type=TIMESTAMP,
|
||
sa_column_kwargs={
|
||
"server_default": text("CURRENT_TIMESTAMP"),
|
||
"onupdate": func.now(),
|
||
},
|
||
)
|
||
|
||
@classmethod
|
||
def no_delete(cls):
|
||
return col(cls.deleted_at).is_(None)
|