- Add DatabaseMigrationManager for orchestrating database migrations, including planning and executing migration steps.
- Introduce models for migration state, execution context, and migration steps.
- Implement MigrationPlanner to generate migration plans based on current and target versions.
- Create MigrationRegistry for registering and managing migration steps.
- Develop SchemaVersionResolver to determine the current database schema version.
- Add SQLiteSchemaInspector for inspecting SQLite database structures.
- Implement progress reporting tools using rich for visualizing migration progress.
- Introduce SQLiteUserVersionStore for managing schema version storage in SQLite.
205 lines
5.9 KiB
Python
205 lines
5.9 KiB
Python
from contextlib import contextmanager
|
||
from pathlib import Path
|
||
from typing import ContextManager, Generator, TYPE_CHECKING
|
||
|
||
from rich.traceback import install
|
||
from sqlalchemy import event, text
|
||
from sqlalchemy.engine import Engine
|
||
from sqlalchemy.orm import sessionmaker
|
||
from sqlmodel import SQLModel, Session, create_engine
|
||
|
||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||
from src.common.logger import get_logger
|
||
|
||
if TYPE_CHECKING:
|
||
from sqlite3 import Connection as SQLite3Connection
|
||
|
||
# Install rich's traceback handler (extra_lines=3) for this process.
install(extra_lines=3)

# Module-level logger for the database layer.
logger = get_logger("database")


# Location of the SQLite database file.
# ROOT_PATH is the directory four levels above this module.
# NOTE(review): presumably the project root -- confirm against the repo layout.
ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve()
_DB_DIR = ROOT_PATH / "data"
_DB_FILE = _DB_DIR / "MaiBot.db"

# Make sure the database directory exists (this runs at import time).
_DB_DIR.mkdir(parents=True, exist_ok=True)
DATABASE_URL = f"sqlite:///{_DB_FILE}"
|
||
|
||
|
||
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection: "SQLite3Connection", connection_record):
    """Apply the project's SQLite PRAGMA settings to a newly opened connection.

    The listener is registered on the ``Engine`` class for the ``"connect"``
    event, so it runs for every new DBAPI connection.

    Args:
        dbapi_connection: The raw DBAPI connection that was just opened.
        connection_record: Pool record for the connection (unused here).
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA cache_size=-64000")  # negative value is in KB: 64000 KB = 64 MB
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.execute("PRAGMA busy_timeout=1000")  # 1 second timeout
    finally:
        # Release the cursor even if one of the PRAGMA statements fails.
        cursor.close()
|
||
|
||
|
||
# 连接数据库
|
||
# Create the engine for the SQLite database at DATABASE_URL.
engine = create_engine(
    DATABASE_URL,
    echo=False,
    # Passed through to sqlite3.connect(): do not restrict a connection
    # to the thread that created it.
    connect_args={"check_same_thread": False},
    pool_pre_ping=True,
)

# Session factory (produces sqlmodel.Session instances bound to `engine`).
SessionLocal = sessionmaker(
    autocommit=False,
    autoflush=False,
    bind=engine,
    class_=Session,
)
# Migration bootstrapper used by initialize_database().
_migration_bootstrapper = create_database_migration_bootstrapper(engine)

# Set to True by initialize_database() once initialization has completed.
_db_initialized = False
|
||
|
||
|
||
def _migrate_action_records_to_tool_records() -> None:
    """Copy legacy ``action_records`` rows into ``tool_records``.

    Each ``action_*`` column is copied into the matching ``tool_*`` column.
    Rows whose ``action_id`` already exists as a ``tool_id`` in
    ``tool_records`` are skipped, so repeated runs do not insert duplicates.

    If either table does not exist the function returns without touching the
    database, instead of failing on the missing table.
    """
    # Function-scope import: the inspector is only needed for the guard below.
    from sqlalchemy import inspect as sa_inspect

    migration_sql = text(
        """
        INSERT INTO tool_records (
            tool_id,
            timestamp,
            session_id,
            tool_name,
            tool_reasoning,
            tool_data,
            tool_builtin_prompt,
            tool_display_prompt
        )
        SELECT
            action_id,
            timestamp,
            session_id,
            action_name,
            action_reasoning,
            action_data,
            action_builtin_prompt,
            action_display_prompt
        FROM action_records
        WHERE NOT EXISTS (
            SELECT 1
            FROM tool_records
            WHERE tool_records.tool_id = action_records.action_id
        )
        """
    )
    with engine.begin() as connection:
        # Check on the same connection that performs the copy.
        inspector = sa_inspect(connection)
        tables_present = inspector.has_table("action_records") and inspector.has_table("tool_records")
        if not tables_present:
            # Nothing to migrate from, or nowhere to migrate to.
            return
        connection.execute(migration_sql)
|
||
|
||
|
||
def initialize_database() -> None:
    """Initialize the database directory, schema and startup migrations.

    The steps run in this order:
        1. Make sure the database directory exists.
        2. Import the SQLModel model definitions.
        3. Let the migration bootstrapper prepare the database.
        4. Call ``create_all`` as a fallback so the current models have tables.
        5. Run the project's lightweight data backfill.
        6. Let the migration bootstrapper finalize the database.

    Once a call has completed, later calls return immediately.
    """
    global _db_initialized

    if _db_initialized:
        return

    _DB_DIR.mkdir(parents=True, exist_ok=True)

    # Imported only for its side effect of loading the model definitions.
    import src.common.database.database_model  # noqa: F401

    state = _migration_bootstrapper.prepare_database()
    current_version = state.resolved_version.version
    target_version = state.target_version
    logger.info(f"数据库迁移准备完成, 当前版本={current_version},目标版本={target_version}")

    SQLModel.metadata.create_all(engine)
    _migrate_action_records_to_tool_records()
    _migration_bootstrapper.finalize_database(state)

    # Only flagged after every step above succeeded.
    _db_initialized = True
|
||
|
||
|
||
@contextmanager
def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
    """Context manager that yields a database session (auto-commit by default).

    Examples:
    ----
    .. code-block:: python
        # Option 1: auto-commit (recommended, the default)
        with get_db_session() as session:
            user = User(name="Alice", age=25)
            session.add(user)
            # committed automatically on exit

        # Option 2: manual transaction control (advanced)
        with get_db_session(auto_commit=False) as session:
            user1 = User(name="Alice", age=25)
            user2 = User(name="Bob", age=30)
            session.add_all([user1, user2])
            session.commit()  # explicit commit

    Args:
        auto_commit (bool): Commit when the ``with`` body finishes without
            raising (default: True).

    Yields:
        Session: The database session.

    Notes:
        - The session is always closed when the context exits.
        - If an ``Exception`` is raised, the transaction is rolled back and
          the exception is re-raised.
        - With ``auto_commit=False`` the caller must call ``session.commit()``.
    """
    initialize_database()

    # The session's own context manager closes it on exit.
    with SessionLocal() as session:
        try:
            yield session
            if auto_commit:
                session.commit()
        except Exception:
            session.rollback()
            raise
|
||
|
||
|
||
def get_db_session_manual() -> ContextManager[Session]:
    """Return a session context manager that never commits automatically.

    Returns:
        ContextManager[Session]: Same as ``get_db_session(auto_commit=False)``;
        the caller is responsible for calling ``session.commit()``.
    """
    manual_session_cm = get_db_session(auto_commit=False)
    return manual_session_cm
|
||
|
||
|
||
def get_db() -> Generator[Session, None, None]:
    """Generator that yields a database session.

    Intended for dependency-injection use (for example FastAPI).

    Example (FastAPI):
    ----
    .. code-block:: python
        @app.get("/users/{user_id}")
        def read_user(user_id: int, db: Session = Depends(get_db)):
            return db.get(User, user_id)

    Yields:
        Session: The database session. It is closed when the generator
        finishes; nothing is committed here.
    """
    initialize_database()

    # The session's own context manager closes it when the generator exits.
    with SessionLocal() as db_session:
        yield db_session
|