添加对 peewee 的旧数据库的兼容层,初步重构插件的 database API
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from rich.traceback import install
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy import create_engine, event, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy import inspect as sqlalchemy_inspect
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
from typing import TYPE_CHECKING, Generator
|
||||
|
||||
@@ -131,3 +132,59 @@ def get_db() -> Generator[Session, None, None]:
|
||||
yield session
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
class _AtomicContext:
|
||||
def __init__(self) -> None:
|
||||
self._session: Session | None = None
|
||||
|
||||
def __enter__(self) -> Session:
|
||||
self._session = SessionLocal()
|
||||
self._session.begin()
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> None:
|
||||
if self._session is None:
|
||||
return
|
||||
try:
|
||||
if exc_type is None:
|
||||
self._session.commit()
|
||||
else:
|
||||
self._session.rollback()
|
||||
finally:
|
||||
self._session.close()
|
||||
|
||||
|
||||
class DatabaseCompat:
|
||||
"""兼容旧 db 调用接口(Peewee 风格),底层使用 SQLAlchemy。"""
|
||||
|
||||
def connect(self, reuse_if_open: bool = True) -> None:
|
||||
# SQLAlchemy 由 engine 按需管理连接,这里保留兼容入口。
|
||||
_ = reuse_if_open
|
||||
|
||||
def create_tables(self, models: list[type], safe: bool = True) -> None:
|
||||
_ = safe
|
||||
tables = [model.__table__ for model in models if hasattr(model, "__table__")]
|
||||
if not tables:
|
||||
return
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
SQLModel.metadata.create_all(engine, tables=tables)
|
||||
|
||||
def atomic(self) -> _AtomicContext:
|
||||
return _AtomicContext()
|
||||
|
||||
def execute_sql(self, sql: str):
|
||||
with engine.connect() as conn:
|
||||
result = conn.execute(text(sql))
|
||||
conn.commit()
|
||||
return result
|
||||
|
||||
def table_exists(self, model: type) -> bool:
|
||||
if not hasattr(model, "__tablename__"):
|
||||
return False
|
||||
inspector = sqlalchemy_inspect(engine)
|
||||
return inspector.has_table(model.__tablename__)
|
||||
|
||||
|
||||
db = DatabaseCompat()
|
||||
|
||||
Reference in New Issue
Block a user