mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2025-09-06 12:40:02 +00:00
128 lines
4 KiB
Python
128 lines
4 KiB
Python
#!/usr/bin/env python
|
|
# coding=utf-8
|
|
'''
|
|
Description :
|
|
Author : chenxl
|
|
Date : 2024-06-12 09:12:58
|
|
Version : 1.0.0
|
|
LastEditors : chenxl
|
|
LastEditTime : 2024-07-27 01:56:04
|
|
'''
|
|
|
|
from urllib.parse import urlparse
|
|
import os
|
|
from contextlib import contextmanager
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import Session, sessionmaker, declarative_base
|
|
|
|
from ktransformers.server.config.config import Config
|
|
from ktransformers.server.config.singleton import Singleton
|
|
from ktransformers.server.config.log import logger
|
|
from ktransformers.server.exceptions import db_exception
|
|
|
|
|
|
# Declarative base class returned by SQLAlchemy's declarative_base().
# NOTE(review): presumably the project's ORM model classes subclass this so
# they share one metadata registry — confirm against the model modules.
Base = declarative_base()
|
|
|
|
|
|
class SQLUtil(metaclass=Singleton):
    """
    Database connection initialisation and session management.

    The engine and the session factory are stored on the class itself, so
    they are created at most once and shared by every user of ``SQLUtil``.
    """
    # Engine created by init_engine(); stays None until first initialisation.
    sqlalchemy_engine = None
    # sessionmaker factory bound to sqlalchemy_engine; None until initialised.
    session_local = None

    def __init__(self) -> None:
        """Load the configuration and make sure the engine exists."""
        self.cfg: Config = Config()
        if not self.sqlalchemy_engine:
            SQLUtil.init_engine(self.cfg)

    @contextmanager
    def get_db(self):
        """
        Yield a new session and close it when the ``with`` block exits.

        The session is closed in a ``finally`` clause, so it is released even
        when the body of the ``with`` block raises. Nothing is committed
        here; committing is left to the caller.
        """
        if not SQLUtil.sqlalchemy_engine:
            SQLUtil.init_engine(self.cfg)
        session = self.session_local()  # type: ignore pylint: disable=not-callable
        try:
            yield session
        finally:
            session.close()

    @staticmethod
    def init_engine(cfg: Config):
        """
        Create the engine and the session factory if they do not exist yet.

        Does nothing when the engine has already been created. Only the
        database type ``"sqllite"`` is supported; for any other
        ``cfg.db_type`` an error is logged and the process exits with
        status -1.

        Args:
            cfg: configuration providing ``db_type``, ``db_pool_size`` and
                the values read by :meth:`create_sqllite_url`.

        Raises:
            SystemExit: when ``cfg.db_type`` is not supported or the
                database url is invalid.
        """
        pool_size = cfg.db_pool_size
        if SQLUtil.sqlalchemy_engine is not None:
            return
        if cfg.db_type != "sqllite":
            logger.error("Unsupported database type %s", cfg.db_type)
            # SystemExit is raised directly: the builtin ``exit`` helper is
            # injected by ``site`` and is not guaranteed to exist.
            raise SystemExit(-1)
        db_url = SQLUtil.create_sqllite_url(cfg)
        SQLUtil.sqlalchemy_engine = create_engine(
            db_url, connect_args={"check_same_thread": False}, pool_size=pool_size)
        SQLUtil.session_local = sessionmaker(
            autocommit=False, autoflush=False, bind=SQLUtil.sqlalchemy_engine)

    @staticmethod
    def create_sqllite_url(cfg):
        """
        Build the SQLite url from the configuration and validate it.

        The url is ``'sqlite:///'`` followed by ``cfg.db_host`` joined with
        ``cfg.db_database``.

        Args:
            cfg: configuration providing ``db_host`` (directory) and
                ``db_database`` (file name).

        Returns:
            The SQLite url string.

        Raises:
            SystemExit: when the url cannot be parsed, or does not have the
                ``sqlite`` scheme and a non-empty path.
        """
        path: str = cfg.db_host
        database: str = cfg.db_database
        url = 'sqlite:///' + os.path.join(path, database)
        try:
            parsed = urlparse(url)
        except ValueError:
            parsed = None
        if parsed is None or parsed.scheme != 'sqlite' or not parsed.path:
            logger.error("invalid sqllite url: %s", url)
            raise SystemExit(-1)
        return url

    def db_add_commit_refresh(self, session: Session, what):
        """
        Add ``what`` to the session, commit, and refresh it from the database.

        Args:
            session: session used for the operation.
            what: object to add; its ``__dict__`` is logged on failure.

        Raises:
            The exception returned by ``db_exception()``, with ``detail`` set
            to the text of the original error, after the session has been
            rolled back.
        """
        try:
            session.add(what)
            session.commit()
            session.refresh(what)
        except Exception as e:
            logger.exception("db commit error with data %s", str(what.__dict__))
            ex = db_exception()
            ex.detail = str(e)
            session.rollback()
            raise ex from e

    def db_merge_commit(self, session: Session, what):
        """
        Merge ``what`` into the session and commit.

        Args:
            session: session used for the operation.
            what: object to merge; its ``__dict__`` is logged on failure.

        Raises:
            The exception returned by ``db_exception()``, with ``detail`` set
            to the text of the original error, after the session has been
            rolled back.
        """
        try:
            session.merge(what)
            session.commit()
        except Exception as e:
            logger.exception("db merge commit error with data %s", str(what.__dict__))
            ex = db_exception()
            ex.detail = str(e)
            session.rollback()
            raise ex from e

    def db_update_commit_refresh(self, session: Session, existing, what):
        """
        Copy the non-None fields of ``what`` onto ``existing``, then commit
        and refresh ``existing``.

        Args:
            session: session used for the operation.
            existing: object to update in place.
            what: object exposing ``model_dump(mode="json")``; every key of
                the dumped mapping whose value is not None is set as an
                attribute on ``existing``.

        Raises:
            The exception returned by ``db_exception()``, with ``detail`` set
            to the text of the original error, after the session has been
            rolled back.
        """
        fields = what.model_dump(mode="json")
        try:
            for key, value in fields.items():
                # None means "field not supplied": keep the existing value.
                if value is not None:
                    setattr(existing, key, value)
            session.commit()
            session.refresh(existing)
        except Exception as e:
            # ``fields`` is a plain dict and has no ``__dict__``; log it
            # directly so this handler cannot fail and skip the rollback.
            logger.exception("db update commit refresh error with data %s", str(fields))
            ex = db_exception()
            ex.detail = str(e)
            session.rollback()
            raise ex from e