Files
mai-bot/src/common/database/database.py
DrSmoothl c2c992ff01 feat(database-migrations): implement database migration manager and related components
- Add DatabaseMigrationManager for orchestrating database migrations, including planning and executing migration steps.
- Introduce models for migration state, execution context, and migration steps.
- Implement MigrationPlanner to generate migration plans based on current and target versions.
- Create MigrationRegistry for registering and managing migration steps.
- Develop SchemaVersionResolver to determine the current database schema version.
- Add SQLiteSchemaInspector for inspecting SQLite database structures.
- Implement progress reporting tools using rich for visualizing migration progress.
- Introduce SQLiteUserVersionStore for managing schema version storage in SQLite.
2026-03-31 09:16:25 +08:00

205 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from contextlib import contextmanager
from pathlib import Path
from typing import ContextManager, Generator, TYPE_CHECKING
from rich.traceback import install
from sqlalchemy import event, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, Session, create_engine
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.logger import get_logger
if TYPE_CHECKING:
from sqlite3 import Connection as SQLite3Connection
# Install rich's traceback handler, configured with three extra lines of context.
install(extra_lines=3)

logger = get_logger("database")

# Database file location: "<ROOT_PATH>/data/MaiBot.db", where ROOT_PATH is
# reached by taking four ``parent`` steps up from this file.
ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve()
_DB_DIR = ROOT_PATH.joinpath("data")
_DB_FILE = _DB_DIR.joinpath("MaiBot.db")

# Make sure the database directory exists.
_DB_DIR.mkdir(parents=True, exist_ok=True)

# SQLAlchemy URL pointing at the SQLite database file.
DATABASE_URL = "sqlite:///" + str(_DB_FILE)
@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection: "SQLite3Connection", connection_record):
    """Apply the SQLite PRAGMA settings to every newly opened connection.

    Registered as a ``connect`` listener on the ``Engine`` class, so it is
    invoked with the raw DBAPI connection each time one is created.

    Args:
        dbapi_connection: The raw DBAPI connection that was just opened.
        connection_record: The second argument passed to ``connect``
            listeners; it is not used here.
    """
    cursor = dbapi_connection.cursor()
    try:
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA cache_size=-64000")  # negative value is in KB: 64000 KB = 64 MB
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.execute("PRAGMA busy_timeout=1000")  # 1 second timeout
    finally:
        # Release the cursor even if one of the PRAGMA statements raises.
        cursor.close()
# Create the database engine.
engine = create_engine(
    DATABASE_URL,
    pool_pre_ping=True,
    connect_args={"check_same_thread": False},
    echo=False,
)

# Session factory (produces sqlmodel.Session instances).
SessionLocal = sessionmaker(
    bind=engine,
    class_=Session,
    autoflush=False,
    autocommit=False,
)

_migration_bootstrapper = create_database_migration_bootstrapper(engine)

# Flipped to True by initialize_database() after a successful run.
_db_initialized = False
def _migrate_action_records_to_tool_records() -> None:
    """Copy legacy ``action_records`` rows into ``tool_records``.

    Every ``action_*`` column is written to the matching ``tool_*`` column;
    ``timestamp`` and ``session_id`` are copied as-is. The ``NOT EXISTS``
    filter skips rows whose ``action_id`` is already present as a ``tool_id``,
    so rows that were copied earlier are not inserted again.

    The statement is executed on a connection obtained from
    ``engine.begin()``.
    """
    copy_sql = """
        INSERT INTO tool_records (
            tool_id,
            timestamp,
            session_id,
            tool_name,
            tool_reasoning,
            tool_data,
            tool_builtin_prompt,
            tool_display_prompt
        )
        SELECT
            action_id,
            timestamp,
            session_id,
            action_name,
            action_reasoning,
            action_data,
            action_builtin_prompt,
            action_display_prompt
        FROM action_records
        WHERE NOT EXISTS (
            SELECT 1
            FROM tool_records
            WHERE tool_records.tool_id = action_records.action_id
        )
        """
    copy_statement = text(copy_sql)

    with engine.begin() as conn:
        conn.execute(copy_statement)
def initialize_database() -> None:
    """Initialize the database directory, schema and startup migrations once.

    Calls after the first successful run return immediately because the
    module-level ``_db_initialized`` flag is set. The steps run in this order:

    1. make sure the database directory exists;
    2. import the SQLModel model definitions module;
    3. call the migration bootstrapper's ``prepare_database()`` and log the
       resolved and target versions it reports;
    4. run ``SQLModel.metadata.create_all`` as a fallback so the tables of the
       current model definitions exist;
    5. run the lightweight ``action_records`` -> ``tool_records`` data copy;
    6. call the migration bootstrapper's ``finalize_database()``.
    """
    global _db_initialized

    if _db_initialized:
        return

    _DB_DIR.mkdir(parents=True, exist_ok=True)

    # Imported only for its side effect of loading the model definitions.
    import src.common.database.database_model  # noqa: F401

    state = _migration_bootstrapper.prepare_database()
    current_version = state.resolved_version.version
    target_version = state.target_version
    logger.info(f"数据库迁移准备完成, 当前版本={current_version},目标版本={target_version}")

    SQLModel.metadata.create_all(engine)
    _migrate_action_records_to_tool_records()
    _migration_bootstrapper.finalize_database(state)

    # Only mark as done once every step above has completed without raising.
    _db_initialized = True
@contextmanager
def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
    """Context manager yielding a database session (auto-commit by default).

    Examples:
    ----
    .. code-block:: python

        # Option 1: auto-commit (recommended, default behavior)
        with get_db_session() as session:
            user = User(name="Zhang San", age=25)
            session.add(user)
            # committed automatically on exit, no manual call needed

        # Option 2: manual transaction control (advanced usage)
        with get_db_session(auto_commit=False) as session:
            user1 = User(name="Zhang San", age=25)
            user2 = User(name="Li Si", age=30)
            session.add_all([user1, user2])
            session.commit()  # commit manually

    Args:
        auto_commit (bool): Whether to commit when the context exits
            without an exception (default: True).

    Yields:
        Session: The database session created by ``SessionLocal``.

    Notes:
        - The session is closed when the context exits.
        - If an ``Exception`` is raised, the session is rolled back and the
          exception is re-raised.
        - With ``auto_commit=True`` the session is committed after the body
          finishes successfully.
        - With ``auto_commit=False`` the caller must call ``session.commit()``.
    """
    initialize_database()

    # Leaving the ``with`` block closes the session on every exit path.
    with SessionLocal() as db_session:
        try:
            yield db_session
            if auto_commit:
                db_session.commit()
        except Exception:
            db_session.rollback()
            raise
def get_db_session_manual() -> ContextManager[Session]:
    """Return a database session context manager in manual-commit mode.

    Equivalent to calling ``get_db_session(auto_commit=False)``.

    Returns:
        ContextManager[Session]: Session context manager that does not
        commit automatically on exit.
    """
    manual_session_context = get_db_session(auto_commit=False)
    return manual_session_context
def get_db() -> Generator[Session, None, None]:
    """Generator function yielding a database session.

    Intended for dependency-injection scenarios (e.g. FastAPI). The session
    is not committed here; it is closed once the generator finishes.

    Example (FastAPI):
    ----
    .. code-block:: python

        @app.get("/users/{user_id}")
        def read_user(user_id: int, db: Session = Depends(get_db)):
            return db.get(User, user_id)

    Yields:
        Session: The database session created by ``SessionLocal``.
    """
    initialize_database()

    # The session's context-manager exit closes it, whether or not the
    # consumer raised.
    with SessionLocal() as db_session:
        yield db_session