feat(database-migrations): implement database migration manager and related components

- Add DatabaseMigrationManager for orchestrating database migrations, including planning and executing migration steps.
- Introduce models for migration state, execution context, and migration steps.
- Implement MigrationPlanner to generate migration plans based on current and target versions.
- Create MigrationRegistry for registering and managing migration steps.
- Develop SchemaVersionResolver to determine the current database schema version.
- Add SQLiteSchemaInspector for inspecting SQLite database structures.
- Implement progress reporting tools using rich for visualizing migration progress.
- Introduce SQLiteUserVersionStore for managing schema version storage in SQLite.
This commit is contained in:
DrSmoothl
2026-03-31 09:16:25 +08:00
parent ea4cea39f2
commit c2c992ff01
15 changed files with 4025 additions and 4 deletions

View File

@@ -0,0 +1,912 @@
"""数据库迁移基础设施测试。"""
from pathlib import Path
from typing import List, Optional, Tuple
from sqlalchemy import text
from sqlalchemy.engine import Connection, Engine
from sqlmodel import SQLModel, create_engine
import json
import msgpack
import pytest
from src.common.database import database as database_module
from src.common.database.migrations import (
BaseSchemaVersionDetector,
BaseMigrationProgressReporter,
DatabaseSchemaSnapshot,
DatabaseMigrationBootstrapper,
DatabaseMigrationState,
DatabaseMigrationManager,
EMPTY_SCHEMA_VERSION,
LATEST_SCHEMA_VERSION,
LEGACY_V1_SCHEMA_VERSION,
MigrationExecutionContext,
MigrationPlan,
MigrationRegistry,
MigrationStep,
ResolvedSchemaVersion,
SchemaVersionResolver,
SchemaVersionSource,
SQLiteSchemaInspector,
SQLiteUserVersionStore,
build_default_migration_registry,
build_default_schema_version_resolver,
create_database_migration_bootstrapper,
)
class FixedVersionDetector(BaseSchemaVersionDetector):
"""测试用固定版本探测器。"""
@property
def name(self) -> str:
"""返回测试探测器名称。
Returns:
str: 探测器名称。
"""
return "fixed_version_detector"
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
"""根据测试表是否存在返回固定版本。
Args:
snapshot: 当前数据库结构快照。
Returns:
Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。
"""
if snapshot.has_table("legacy_records"):
return 2
return None
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
"""测试用迁移进度上报器。"""
def __init__(self) -> None:
"""初始化测试用进度上报器。"""
self.events: List[Tuple[str, Optional[int], Optional[str], Optional[str]]] = []
def open(self) -> None:
"""记录打开事件。"""
self.events.append(("open", None, None, None))
def close(self) -> None:
"""记录关闭事件。"""
self.events.append(("close", None, None, None))
def start(
self,
total: int,
description: str = "总迁移进度",
unit_name: str = "",
) -> None:
"""记录启动事件。
Args:
total: 任务总数。
description: 任务描述。
unit_name: 进度单位名称。
"""
self.events.append(("start", total, description, unit_name))
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
"""记录推进事件。
Args:
advance: 推进步数。
item_name: 当前完成的项目名称。
"""
self.events.append(("advance", advance, item_name, None))
def _create_sqlite_engine(database_file: Path) -> Engine:
"""创建测试用 SQLite 引擎。
Args:
database_file: 测试数据库文件路径。
Returns:
Engine: SQLite 引擎实例。
"""
return create_engine(
f"sqlite:///{database_file}",
echo=False,
connect_args={"check_same_thread": False},
)
def _create_current_schema(connection: Connection) -> None:
"""创建当前最新版本的数据库结构。
Args:
connection: 当前数据库连接。
"""
import src.common.database.database_model # noqa: F401
SQLModel.metadata.create_all(connection)
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
"""创建带示例数据的旧版 ``0.x`` 数据库结构。
Args:
connection: 当前数据库连接。
"""
connection.execute(
text(
"""
CREATE TABLE chat_streams (
id INTEGER PRIMARY KEY,
stream_id TEXT NOT NULL,
create_time REAL NOT NULL,
last_active_time REAL NOT NULL,
platform TEXT NOT NULL,
user_id TEXT,
group_id TEXT,
group_name TEXT
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE messages (
id INTEGER PRIMARY KEY,
message_id TEXT NOT NULL,
time REAL NOT NULL,
chat_id TEXT NOT NULL,
chat_info_platform TEXT,
user_id TEXT,
user_nickname TEXT,
chat_info_group_id TEXT,
chat_info_group_name TEXT,
is_mentioned INTEGER,
is_at INTEGER,
processed_plain_text TEXT,
display_message TEXT,
is_emoji INTEGER,
is_picid INTEGER,
is_command INTEGER,
is_notify INTEGER,
additional_config TEXT,
priority_mode TEXT
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE action_records (
id INTEGER PRIMARY KEY,
action_id TEXT NOT NULL,
time REAL NOT NULL,
action_reasoning TEXT,
action_name TEXT NOT NULL,
action_data TEXT,
action_prompt_display TEXT,
chat_id TEXT
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE expression (
id INTEGER PRIMARY KEY,
situation TEXT NOT NULL,
style TEXT NOT NULL,
content_list TEXT,
count INTEGER,
last_active_time REAL NOT NULL,
chat_id TEXT,
create_date REAL,
checked INTEGER,
rejected INTEGER,
modified_by TEXT
)
"""
)
)
connection.execute(
text(
"""
CREATE TABLE jargon (
id INTEGER PRIMARY KEY,
content TEXT NOT NULL,
raw_content TEXT,
meaning TEXT,
chat_id TEXT,
is_global INTEGER,
count INTEGER,
is_jargon INTEGER,
last_inference_count INTEGER,
is_complete INTEGER,
inference_with_context TEXT,
inference_content_only TEXT
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO chat_streams (
id,
stream_id,
create_time,
last_active_time,
platform,
user_id,
group_id,
group_name
) VALUES (
1,
'session-1',
1710000000.0,
1710000300.0,
'qq',
'user-1',
'group-1',
'测试群'
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO messages (
id,
message_id,
time,
chat_id,
chat_info_platform,
user_id,
user_nickname,
chat_info_group_id,
chat_info_group_name,
is_mentioned,
is_at,
processed_plain_text,
display_message,
is_emoji,
is_picid,
is_command,
is_notify,
additional_config,
priority_mode
) VALUES (
1,
'msg-1',
1710000010.0,
'session-1',
'qq',
'user-1',
'测试用户',
'group-1',
'测试群',
1,
0,
'你好',
'你好呀',
0,
1,
0,
1,
'{"source":"legacy"}',
'high'
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO action_records (
id,
action_id,
time,
action_reasoning,
action_name,
action_data,
action_prompt_display,
chat_id
) VALUES (
1,
'action-1',
1710000020.0,
'需要调用工具',
'search',
'{"query":"MaiBot"}',
'执行搜索',
'session-1'
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO expression (
id,
situation,
style,
content_list,
count,
last_active_time,
chat_id,
create_date,
checked,
rejected,
modified_by
) VALUES (
1,
'打招呼',
'可爱',
'["你好呀","早上好"]',
3,
1710000030.0,
'session-1',
1710000040.0,
1,
0,
'ai'
)
"""
)
)
connection.execute(
text(
"""
INSERT INTO jargon (
id,
content,
raw_content,
meaning,
chat_id,
is_global,
count,
is_jargon,
last_inference_count,
is_complete,
inference_with_context,
inference_content_only
) VALUES (
1,
'上分',
'["上分"]',
'提高排名',
'session-1',
0,
5,
1,
2,
1,
'{"guess":"context"}',
'{"guess":"content"}'
)
"""
)
)
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
"""应支持读取与写入 SQLite ``user_version``。"""
engine = _create_sqlite_engine(tmp_path / "version_store.db")
version_store = SQLiteUserVersionStore()
with engine.begin() as connection:
assert version_store.read_version(connection) == 0
version_store.write_version(connection, 7)
with engine.connect() as connection:
assert version_store.read_version(connection) == 7
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
"""应能提取 SQLite 数据库的表与列结构。"""
engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
inspector = SQLiteSchemaInspector()
with engine.begin() as connection:
connection.execute(
text(
"""
CREATE TABLE legacy_records (
id INTEGER PRIMARY KEY,
payload TEXT NOT NULL,
created_at TEXT
)
"""
)
)
with engine.connect() as connection:
snapshot = inspector.inspect(connection)
assert snapshot.has_table("legacy_records")
assert snapshot.has_column("legacy_records", "payload")
assert not snapshot.has_column("legacy_records", "missing_column")
table_schema = snapshot.get_table("legacy_records")
assert table_schema is not None
assert table_schema.column_names() == ["created_at", "id", "payload"]
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
"""空数据库应被解析为版本 ``0``。"""
engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
resolver = SchemaVersionResolver()
with engine.connect() as connection:
resolved_version = resolver.resolve(connection)
assert resolved_version.version == 0
assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
assert resolved_version.snapshot is not None
assert resolved_version.snapshot.is_empty()
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
"""未写入 ``user_version`` 的历史库应支持通过探测器识别版本。"""
engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
with engine.begin() as connection:
connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
with engine.connect() as connection:
resolved_version = resolver.resolve(connection)
assert resolved_version.version == 2
assert resolved_version.source == SchemaVersionSource.DETECTOR
assert resolved_version.detector_name == "fixed_version_detector"
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
"""迁移编排器应能按顺序执行已注册步骤并更新版本号。"""
engine = _create_sqlite_engine(tmp_path / "manager.db")
executed_steps: List[str] = []
def migrate_0_to_1(context: MigrationExecutionContext) -> None:
"""测试迁移步骤 0 -> 1。
Args:
context: 当前迁移步骤执行上下文。
"""
executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
"""测试迁移步骤 1 -> 2。
Args:
context: 当前迁移步骤执行上下文。
"""
executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))
registry = MigrationRegistry(
steps=[
MigrationStep(
version_from=0,
version_to=1,
name="create_sample_records",
description="创建示例表。",
handler=migrate_0_to_1,
),
MigrationStep(
version_from=1,
version_to=2,
name="add_sample_email",
description="为示例表增加邮箱字段。",
handler=migrate_1_to_2,
),
]
)
manager = DatabaseMigrationManager(engine=engine, registry=registry)
migration_plan = manager.migrate()
assert migration_plan.step_count() == 2
assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
with engine.connect() as connection:
version_store = SQLiteUserVersionStore()
snapshot = SQLiteSchemaInspector().inspect(connection)
recorded_version = version_store.read_version(connection)
assert recorded_version == 2
assert snapshot.has_table("sample_records")
assert snapshot.has_column("sample_records", "email")
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
"""迁移编排器应支持通过上下文上报步骤进度。"""
engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
reporter_instances: List[FakeMigrationProgressReporter] = []
def _build_reporter() -> BaseMigrationProgressReporter:
"""构建测试用进度上报器。
Returns:
BaseMigrationProgressReporter: 测试用进度上报器实例。
"""
reporter = FakeMigrationProgressReporter()
reporter_instances.append(reporter)
return reporter
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
"""测试迁移步骤 ``1 -> 2`` 的进度上报。
Args:
context: 当前迁移步骤执行上下文。
"""
context.start_progress(total=3, description="总迁移进度", unit_name="")
context.advance_progress(item_name="chat_sessions")
context.advance_progress(item_name="mai_messages")
context.advance_progress(item_name="tool_records")
context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))
with engine.begin() as connection:
SQLiteUserVersionStore().write_version(connection, 1)
registry = MigrationRegistry(
steps=[
MigrationStep(
version_from=1,
version_to=2,
name="progress_step",
description="测试进度上报。",
handler=migrate_1_to_2,
)
]
)
manager = DatabaseMigrationManager(
engine=engine,
registry=registry,
progress_reporter_factory=_build_reporter,
)
migration_plan = manager.migrate()
assert migration_plan.step_count() == 1
assert len(reporter_instances) == 1
assert reporter_instances[0].events == [
("open", None, None, None),
("start", 3, "总迁移进度", ""),
("advance", 1, "chat_sessions", None),
("advance", 1, "mai_messages", None),
("advance", 1, "tool_records", None),
("close", None, None, None),
]
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
"""默认解析器应能识别未写入版本号的最新结构数据库。"""
engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
resolver = build_default_schema_version_resolver()
with engine.begin() as connection:
_create_current_schema(connection)
with engine.connect() as connection:
resolved_version = resolver.resolve(connection)
assert resolved_version.version == LATEST_SCHEMA_VERSION
assert resolved_version.source == SchemaVersionSource.DETECTOR
assert resolved_version.detector_name == "latest_schema_detector"
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
"""默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。"""
engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
resolver = build_default_schema_version_resolver()
with engine.begin() as connection:
_create_legacy_v1_schema_with_sample_data(connection)
with engine.connect() as connection:
resolved_version = resolver.resolve(connection)
assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
assert resolved_version.source == SchemaVersionSource.DETECTOR
assert resolved_version.detector_name == "legacy_v1_schema_detector"
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
"""已是最新结构但未写版本号的数据库应直接补写 ``user_version``。"""
engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
bootstrapper = create_database_migration_bootstrapper(engine)
with engine.begin() as connection:
_create_current_schema(connection)
migration_state = bootstrapper.prepare_database()
bootstrapper.finalize_database(migration_state)
assert not migration_state.requires_migration()
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR
with engine.connect() as connection:
recorded_version = SQLiteUserVersionStore().read_version(connection)
assert recorded_version == LATEST_SCHEMA_VERSION
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
"""空库在建表完成后应回写最新 ``user_version``。"""
engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
bootstrapper = create_database_migration_bootstrapper(engine)
migration_state = bootstrapper.prepare_database()
assert not migration_state.requires_migration()
assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
assert migration_state.target_version == LATEST_SCHEMA_VERSION
with engine.begin() as connection:
_create_current_schema(connection)
bootstrapper.finalize_database(migration_state)
with engine.connect() as connection:
recorded_version = SQLiteUserVersionStore().read_version(connection)
assert recorded_version == LATEST_SCHEMA_VERSION
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
"""启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。"""
engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
execution_marks: List[str] = []
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
"""测试桥接器迁移步骤 ``1 -> 2``。
Args:
context: 当前迁移步骤执行上下文。
"""
execution_marks.append(f"step={context.step_name},index={context.step_index}")
context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))
with engine.begin() as connection:
connection.execute(
text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
)
SQLiteUserVersionStore().write_version(connection, 1)
registry = MigrationRegistry(
steps=[
MigrationStep(
version_from=1,
version_to=2,
name="bootstrap_add_email",
description="为桥接器测试表增加邮箱字段。",
handler=migrate_1_to_2,
)
]
)
bootstrapper = DatabaseMigrationBootstrapper(
manager=DatabaseMigrationManager(engine=engine, registry=registry),
latest_schema_version=2,
)
migration_state = bootstrapper.prepare_database()
assert migration_state.resolved_version.version == 2
assert migration_state.target_version == 2
assert execution_marks == ["step=bootstrap_add_email,index=1"]
with engine.connect() as connection:
snapshot = SQLiteSchemaInspector().inspect(connection)
recorded_version = SQLiteUserVersionStore().read_version(connection)
assert recorded_version == 2
assert snapshot.has_table("bootstrap_records")
assert snapshot.has_column("bootstrap_records", "email")
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
"""默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。"""
engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
bootstrapper = create_database_migration_bootstrapper(engine)
with engine.begin() as connection:
_create_legacy_v1_schema_with_sample_data(connection)
migration_state = bootstrapper.prepare_database()
bootstrapper.finalize_database(migration_state)
assert not migration_state.requires_migration()
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
with engine.connect() as connection:
recorded_version = SQLiteUserVersionStore().read_version(connection)
snapshot = SQLiteSchemaInspector().inspect(connection)
message_row = connection.execute(
text(
"""
SELECT session_id, processed_plain_text, additional_config, raw_content
FROM mai_messages
WHERE message_id = 'msg-1'
"""
)
).mappings().one()
action_row = connection.execute(
text(
"""
SELECT session_id, action_name, action_display_prompt
FROM action_records
WHERE action_id = 'action-1'
"""
)
).mappings().one()
tool_row = connection.execute(
text(
"""
SELECT session_id, tool_name, tool_display_prompt
FROM tool_records
WHERE tool_id = 'action-1'
"""
)
).mappings().one()
expression_row = connection.execute(
text(
"""
SELECT session_id, content_list, modified_by
FROM expressions
WHERE id = 1
"""
)
).mappings().one()
jargon_row = connection.execute(
text(
"""
SELECT session_id_dict, raw_content, inference_with_content_only
FROM jargons
WHERE id = 1
"""
)
).mappings().one()
assert recorded_version == LATEST_SCHEMA_VERSION
assert snapshot.has_table("__legacy_v1_messages")
assert snapshot.has_table("chat_sessions")
assert snapshot.has_table("mai_messages")
assert snapshot.has_table("tool_records")
unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
additional_config = json.loads(message_row["additional_config"])
expression_content_list = json.loads(expression_row["content_list"])
jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
jargon_raw_content = json.loads(jargon_row["raw_content"])
assert message_row["session_id"] == "session-1"
assert message_row["processed_plain_text"] == "你好"
assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
assert additional_config == {"priority_mode": "high", "source": "legacy"}
assert action_row["session_id"] == "session-1"
assert action_row["action_name"] == "search"
assert action_row["action_display_prompt"] == "执行搜索"
assert tool_row["session_id"] == "session-1"
assert tool_row["tool_name"] == "search"
assert tool_row["tool_display_prompt"] == "执行搜索"
assert expression_row["session_id"] == "session-1"
assert expression_row["modified_by"] == "AI"
assert expression_content_list == ["你好呀", "早上好"]
assert jargon_session_id_dict == {"session-1": 5}
assert jargon_raw_content == ["上分"]
assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
"""旧版迁移步骤应按目标表数量推进总进度。"""
engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
reporter_instances: List[FakeMigrationProgressReporter] = []
def _build_reporter() -> BaseMigrationProgressReporter:
"""构建测试用进度上报器。
Returns:
BaseMigrationProgressReporter: 测试用进度上报器实例。
"""
reporter = FakeMigrationProgressReporter()
reporter_instances.append(reporter)
return reporter
with engine.begin() as connection:
_create_legacy_v1_schema_with_sample_data(connection)
manager = DatabaseMigrationManager(
engine=engine,
registry=build_default_migration_registry(),
resolver=build_default_schema_version_resolver(),
progress_reporter_factory=_build_reporter,
)
migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
assert migration_plan.step_count() == 1
assert len(reporter_instances) == 1
reporter_events = reporter_instances[0].events
assert reporter_events[0] == ("open", None, None, None)
assert reporter_events[1] == ("start", 12, "总迁移进度", "")
assert reporter_events[-1] == ("close", None, None, None)
assert reporter_events.count(("advance", 1, "chat_sessions", None)) == 1
assert reporter_events.count(("advance", 1, "thinking_questions", None)) == 1
assert len([event for event in reporter_events if event[0] == "advance"]) == 12
def test_initialize_database_calls_bootstrapper_before_create_all(
monkeypatch: pytest.MonkeyPatch,
tmp_path: Path,
) -> None:
"""数据库初始化入口应先准备迁移,再建表、补迁移并收尾。"""
call_order: List[str] = []
def _fake_prepare_database() -> DatabaseMigrationState:
"""返回测试用迁移状态。
Returns:
DatabaseMigrationState: 不包含迁移步骤的测试状态。
"""
call_order.append("prepare_database")
return DatabaseMigrationState(
resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
target_version=LATEST_SCHEMA_VERSION,
plan=MigrationPlan(
current_version=EMPTY_SCHEMA_VERSION,
target_version=LATEST_SCHEMA_VERSION,
steps=[],
),
)
def _fake_create_all(bind) -> None:
"""记录建表调用。
Args:
bind: 传入的数据库绑定对象。
"""
del bind
call_order.append("create_all")
def _fake_migrate_action_records() -> None:
"""记录轻量补迁移调用。"""
call_order.append("migrate_action_records")
def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
"""记录迁移收尾调用。
Args:
migration_state: 当前数据库迁移状态。
"""
del migration_state
call_order.append("finalize_database")
monkeypatch.setattr(database_module, "_db_initialized", False)
monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
monkeypatch.setattr(database_module, "_migrate_action_records_to_tool_records", _fake_migrate_action_records)
database_module.initialize_database()
assert call_order == [
"prepare_database",
"create_all",
"migrate_action_records",
"finalize_database",
]

View File

@@ -1,18 +1,23 @@
from rich.traceback import install
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, TYPE_CHECKING
from typing import ContextManager, Generator, TYPE_CHECKING
from rich.traceback import install
from sqlalchemy import event, text
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, Session, create_engine
from src.common.database.migrations import create_database_migration_bootstrapper
from src.common.logger import get_logger
if TYPE_CHECKING:
from sqlite3 import Connection as SQLite3Connection
install(extra_lines=3)
logger = get_logger("database")
# 定义数据库文件路径
ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve()
@@ -53,6 +58,7 @@ SessionLocal = sessionmaker(
bind=engine,
class_=Session,
)
_migration_bootstrapper = create_database_migration_bootstrapper(engine)
_db_initialized = False
@@ -93,14 +99,29 @@ def _migrate_action_records_to_tool_records() -> None:
def initialize_database() -> None:
"""初始化数据库连接、结构与启动期迁移。
当前初始化流程遵循以下顺序:
1. 确保数据库目录存在;
2. 加载 SQLModel 模型定义;
3. 执行已注册的启动期迁移;
4. 兜底执行 ``create_all`` 确保当前模型定义已建表;
5. 执行项目现有的轻量数据补迁移逻辑。
"""
global _db_initialized
if _db_initialized:
return
_DB_DIR.mkdir(parents=True, exist_ok=True)
import src.common.database.database_model # noqa: F401
migration_state = _migration_bootstrapper.prepare_database()
logger.info(
"数据库迁移准备完成,"
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
)
SQLModel.metadata.create_all(engine)
_migrate_action_records_to_tool_records()
_migration_bootstrapper.finalize_database(migration_state)
_db_initialized = True
@@ -150,8 +171,12 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
session.close()
def get_db_session_manual():
"""获取数据库会话的上下文管理器 (手动提交模式)。"""
def get_db_session_manual() -> ContextManager[Session]:
"""获取数据库会话的上下文管理器 (手动提交模式)。
Returns:
ContextManager[Session]: 手动提交模式的数据库会话上下文管理器。
"""
return get_db_session(auto_commit=False)

View File

@@ -0,0 +1,79 @@
"""数据库迁移基础设施导出模块。"""
from .bootstrap import DatabaseMigrationBootstrapper, create_database_migration_bootstrapper
from .builtin import (
EMPTY_SCHEMA_VERSION,
LATEST_SCHEMA_VERSION,
LEGACY_V1_SCHEMA_VERSION,
build_default_migration_registry,
build_default_schema_version_resolver,
)
from .exceptions import (
DatabaseMigrationConfigurationError,
DatabaseMigrationError,
DatabaseMigrationExecutionError,
DatabaseMigrationPlanningError,
DatabaseMigrationVersionError,
MissingMigrationStepError,
UnrecognizedDatabaseSchemaError,
UnsupportedMigrationDirectionError,
)
from .manager import DatabaseMigrationManager
from .models import (
ColumnSchema,
DatabaseMigrationState,
DatabaseSchemaSnapshot,
MigrationExecutionContext,
MigrationPlan,
MigrationStep,
ResolvedSchemaVersion,
SchemaVersionSource,
TableSchema,
)
from .planner import MigrationPlanner
from .progress import (
BaseMigrationProgressReporter,
RichMigrationProgressReporter,
create_rich_migration_progress_reporter,
)
from .registry import MigrationRegistry
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
from .schema import SQLiteSchemaInspector
from .version_store import SQLiteUserVersionStore
__all__ = [
"BaseSchemaVersionDetector",
"BaseMigrationProgressReporter",
"build_default_migration_registry",
"build_default_schema_version_resolver",
"ColumnSchema",
"create_database_migration_bootstrapper",
"create_rich_migration_progress_reporter",
"DatabaseMigrationConfigurationError",
"DatabaseMigrationError",
"DatabaseMigrationBootstrapper",
"DatabaseMigrationExecutionError",
"DatabaseMigrationManager",
"DatabaseMigrationPlanningError",
"DatabaseMigrationState",
"DatabaseMigrationVersionError",
"DatabaseSchemaSnapshot",
"EMPTY_SCHEMA_VERSION",
"LATEST_SCHEMA_VERSION",
"LEGACY_V1_SCHEMA_VERSION",
"MigrationExecutionContext",
"MigrationPlan",
"MigrationPlanner",
"MigrationRegistry",
"MigrationStep",
"MissingMigrationStepError",
"ResolvedSchemaVersion",
"RichMigrationProgressReporter",
"SchemaVersionResolver",
"SchemaVersionSource",
"SQLiteSchemaInspector",
"SQLiteUserVersionStore",
"TableSchema",
"UnrecognizedDatabaseSchemaError",
"UnsupportedMigrationDirectionError",
]

View File

@@ -0,0 +1,171 @@
"""数据库迁移启动桥接层。"""
from typing import Optional
from sqlalchemy.engine import Engine
from src.common.logger import get_logger
from .builtin import (
LATEST_SCHEMA_VERSION,
build_default_migration_registry,
build_default_schema_version_resolver,
)
from .exceptions import DatabaseMigrationExecutionError
from .manager import DatabaseMigrationManager
from .models import DatabaseMigrationState, MigrationPlan, ResolvedSchemaVersion, SchemaVersionSource
from .registry import MigrationRegistry
from .resolver import SchemaVersionResolver
from .version_store import SQLiteUserVersionStore
logger = get_logger("database_migration")
class DatabaseMigrationBootstrapper:
"""数据库迁移启动桥接器。
该桥接器负责把数据库迁移基础设施接入现有启动流程,同时保持如下约束:
1. 若数据库为空,则直接交给当前模型定义建出最新结构;
2. 若数据库版本高于当前代码支持的最新版本,则立即终止启动;
3. 若存在待执行迁移步骤,则在正常建表流程之前先执行迁移;
4. 若数据库已是最新结构但尚未写入 ``user_version``,则在建表后补写版本号。
"""
def __init__(
self,
manager: DatabaseMigrationManager,
latest_schema_version: int = LATEST_SCHEMA_VERSION,
) -> None:
"""初始化数据库迁移启动桥接器。
Args:
manager: 数据库迁移编排器。
latest_schema_version: 当前代码支持的最新 schema 版本号。
"""
self.manager = manager
self.latest_schema_version = latest_schema_version
def prepare_database(self) -> DatabaseMigrationState:
"""为数据库初始化阶段准备迁移状态。
Returns:
DatabaseMigrationState: 迁移准备完成后的数据库状态。
Raises:
DatabaseMigrationExecutionError: 当数据库版本高于当前代码支持版本时抛出。
"""
with self.manager.engine.connect() as connection:
resolved_version = self.manager.resolver.resolve(connection)
if resolved_version.version > self.latest_schema_version:
raise DatabaseMigrationExecutionError(
"当前数据库版本高于代码内注册的最新迁移版本,已拒绝继续启动。"
f" 数据库版本={resolved_version.version},代码支持版本={self.latest_schema_version}"
)
if resolved_version.source == SchemaVersionSource.EMPTY_DATABASE:
logger.info(
"检测到空数据库,将直接根据当前模型创建最新结构。"
f" 目标版本={self.latest_schema_version}"
)
return self._build_noop_state(
current_version=resolved_version.version,
target_version=self.latest_schema_version,
resolved_state=resolved_version,
)
migration_state = self.manager.describe_state(target_version=self.latest_schema_version)
if not migration_state.requires_migration():
logger.info(
f"数据库 schema 已是目标版本,无需迁移。当前版本={migration_state.resolved_version.version}"
)
return migration_state
logger.info(
"检测到数据库需要迁移,"
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
)
self.manager.migrate(target_version=self.latest_schema_version)
return self.manager.describe_state(target_version=self.latest_schema_version)
def finalize_database(self, migration_state: DatabaseMigrationState) -> None:
"""在数据库初始化末尾补写最终 schema 版本号。
该方法主要负责两类场景:
1. 空库首次建表完成后,将 ``user_version`` 写入为最新版本;
2. 已是最新结构但此前未写入 ``user_version`` 的数据库,补写版本号。
Args:
migration_state: 初始化前解析得到的迁移状态。
"""
if migration_state.requires_migration():
return
if migration_state.target_version <= 0:
return
if migration_state.resolved_version.source == SchemaVersionSource.PRAGMA:
return
with self.manager.engine.begin() as connection:
self.manager.version_store.write_version(connection, migration_state.target_version)
logger.info(
"数据库 schema 版本写入完成。"
f" 来源={migration_state.resolved_version.source.value}"
f" 写入版本={migration_state.target_version}"
)
def _build_noop_state(
self,
current_version: int,
target_version: int,
resolved_state: ResolvedSchemaVersion,
) -> DatabaseMigrationState:
"""构建无迁移动作的数据库状态对象。
Args:
current_version: 当前数据库版本号。
target_version: 当前初始化流程期望达到的目标版本号。
resolved_state: 已解析的数据库版本状态。
Returns:
DatabaseMigrationState: 不包含迁移步骤的状态对象。
"""
return DatabaseMigrationState(
resolved_version=resolved_state,
target_version=target_version,
plan=MigrationPlan(current_version=current_version, target_version=target_version, steps=[]),
)
def create_database_migration_bootstrapper(
engine: Engine,
registry: Optional[MigrationRegistry] = None,
resolver: Optional[SchemaVersionResolver] = None,
version_store: Optional[SQLiteUserVersionStore] = None,
latest_schema_version: int = LATEST_SCHEMA_VERSION,
) -> DatabaseMigrationBootstrapper:
"""创建数据库迁移启动桥接器。
Args:
engine: 目标数据库引擎。
registry: 迁移步骤注册表;未提供时使用默认注册表。
resolver: 数据库版本解析器;未提供时使用默认解析器。
version_store: 版本存储器;未提供时使用默认存储器。
latest_schema_version: 当前代码支持的最新 schema 版本号。
Returns:
DatabaseMigrationBootstrapper: 配置完成的数据库迁移启动桥接器。
"""
migration_registry = registry or build_default_migration_registry()
migration_resolver = resolver or build_default_schema_version_resolver()
migration_version_store = version_store or SQLiteUserVersionStore()
migration_manager = DatabaseMigrationManager(
engine=engine,
registry=migration_registry,
resolver=migration_resolver,
version_store=migration_version_store,
)
return DatabaseMigrationBootstrapper(
manager=migration_manager,
latest_schema_version=latest_schema_version,
)

View File

@@ -0,0 +1,159 @@
"""数据库迁移内置版本与默认注册表。"""
from typing import List, Optional
from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2
from .models import DatabaseSchemaSnapshot, MigrationStep
from .registry import MigrationRegistry
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
from .version_store import SQLiteUserVersionStore
from .schema import SQLiteSchemaInspector
EMPTY_SCHEMA_VERSION = 0
LEGACY_V1_SCHEMA_VERSION = 1
LATEST_SCHEMA_VERSION = 2
_LEGACY_V1_EXCLUSIVE_TABLES = (
"chat_streams",
"emoji",
"emoji_description_cache",
"expression",
"group_info",
"image_descriptions",
"jargon",
"messages",
"thinking_back",
)
class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
"""当前最新 schema 结构探测器。"""
@property
def name(self) -> str:
"""返回探测器名称。
Returns:
str: 当前探测器名称。
"""
return "latest_schema_detector"
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
"""检测数据库是否已经是当前最新结构。
Args:
snapshot: 当前数据库结构快照。
Returns:
Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。
"""
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
return None
latest_marker_tables = (
"mai_messages",
"chat_sessions",
"expressions",
"jargons",
"thinking_questions",
"tool_records",
)
if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables):
return None
if not snapshot.has_column("images", "image_hash"):
return None
if not snapshot.has_column("images", "full_path"):
return None
if not snapshot.has_column("images", "image_type"):
return None
if not snapshot.has_column("action_records", "session_id"):
return None
if not snapshot.has_column("chat_history", "session_id"):
return None
if not snapshot.has_column("person_info", "user_nickname"):
return None
return LATEST_SCHEMA_VERSION
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
"""旧版 ``0.x`` schema 结构探测器。"""
@property
def name(self) -> str:
"""返回探测器名称。
Returns:
str: 当前探测器名称。
"""
return "legacy_v1_schema_detector"
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
"""检测数据库是否为旧版 ``0.x`` 结构。
Args:
snapshot: 当前数据库结构快照。
Returns:
Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。
"""
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
return LEGACY_V1_SCHEMA_VERSION
legacy_shared_markers = (
("action_records", ("chat_id", "time")),
("chat_history", ("chat_id", "original_text")),
("images", ("emoji_hash", "path", "type")),
("llm_usage", ("model_api_provider", "status")),
("online_time", ("duration",)),
("person_info", ("nickname", "group_nick_name")),
)
for table_name, required_columns in legacy_shared_markers:
if snapshot.has_table(table_name) and all(
snapshot.has_column(table_name, column_name) for column_name in required_columns
):
return LEGACY_V1_SCHEMA_VERSION
return None
def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
"""构建默认 schema 版本探测器链。
Returns:
List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。
"""
return [
LatestSchemaVersionDetector(),
LegacyV1SchemaDetector(),
]
def build_default_schema_version_resolver() -> SchemaVersionResolver:
"""构建默认 schema 版本解析器。
Returns:
SchemaVersionResolver: 配置完成的 schema 版本解析器。
"""
return SchemaVersionResolver(
version_store=SQLiteUserVersionStore(),
schema_inspector=SQLiteSchemaInspector(),
detectors=build_default_schema_version_detectors(),
)
def build_default_migration_registry() -> MigrationRegistry:
"""构建默认迁移步骤注册表。
Returns:
MigrationRegistry: 含默认迁移步骤的注册表实例。
"""
return MigrationRegistry(
steps=[
MigrationStep(
version_from=LEGACY_V1_SCHEMA_VERSION,
version_to=LATEST_SCHEMA_VERSION,
name="legacy_v1_to_latest_v2",
description="将旧版 0.x 数据库整体迁移到当前最新 schema。",
handler=migrate_legacy_v1_to_v2,
)
]
)

View File

@@ -0,0 +1,33 @@
"""数据库迁移基础设施异常定义。"""
class DatabaseMigrationError(Exception):
"""数据库迁移基础异常。"""
class DatabaseMigrationConfigurationError(DatabaseMigrationError):
"""数据库迁移配置不合法。"""
class DatabaseMigrationPlanningError(DatabaseMigrationError):
"""数据库迁移计划生成失败。"""
class DatabaseMigrationExecutionError(DatabaseMigrationError):
"""数据库迁移执行失败。"""
class DatabaseMigrationVersionError(DatabaseMigrationError):
"""数据库版本读写或校验失败。"""
class MissingMigrationStepError(DatabaseMigrationPlanningError):
"""缺少某个版本区间所需的迁移步骤。"""
class UnsupportedMigrationDirectionError(DatabaseMigrationPlanningError):
"""当前迁移方向不被支持。"""
class UnrecognizedDatabaseSchemaError(DatabaseMigrationVersionError):
"""无法识别未标记版本数据库的结构。"""

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,205 @@
"""数据库迁移编排器。"""
from typing import Callable, Optional
from sqlalchemy.engine import Connection, Engine
from src.common.logger import get_logger
from .exceptions import DatabaseMigrationExecutionError
from .models import DatabaseMigrationState, MigrationExecutionContext, MigrationPlan
from .planner import MigrationPlanner
from .progress import BaseMigrationProgressReporter, create_rich_migration_progress_reporter
from .registry import MigrationRegistry
from .resolver import SchemaVersionResolver
from .version_store import SQLiteUserVersionStore
logger = get_logger("database_migration")
class DatabaseMigrationManager:
"""数据库迁移编排器。
该类只负责基础设施层面的编排工作,包括:
1. 解析当前数据库版本;
2. 生成迁移计划;
3. 顺序执行已注册迁移步骤;
4. 在每一步成功后更新 ``user_version``。
当前模块不内置任何业务迁移步骤,也不会自动接入项目启动流程。
"""
def __init__(
self,
engine: Engine,
registry: Optional[MigrationRegistry] = None,
planner: Optional[MigrationPlanner] = None,
resolver: Optional[SchemaVersionResolver] = None,
version_store: Optional[SQLiteUserVersionStore] = None,
progress_reporter_factory: Optional[Callable[[], BaseMigrationProgressReporter]] = None,
) -> None:
"""初始化数据库迁移编排器。
Args:
engine: 目标数据库引擎。
registry: 迁移步骤注册表。
planner: 迁移计划生成器。
resolver: 数据库版本解析器。
version_store: 版本存储器。
progress_reporter_factory: 迁移进度上报器工厂。
"""
self.engine = engine
self.registry = registry or MigrationRegistry()
self.planner = planner or MigrationPlanner()
self.resolver = resolver or SchemaVersionResolver()
self.version_store = version_store or SQLiteUserVersionStore()
self.progress_reporter_factory = progress_reporter_factory or create_rich_migration_progress_reporter
def describe_state(self, target_version: Optional[int] = None) -> DatabaseMigrationState:
"""描述当前数据库的迁移状态。
Args:
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
Returns:
DatabaseMigrationState: 当前数据库迁移状态。
"""
with self.engine.connect() as connection:
resolved_version = self.resolver.resolve(connection)
effective_target_version = self._resolve_target_version(target_version)
migration_plan = self.planner.plan(
current_version=resolved_version.version,
target_version=effective_target_version,
registry=self.registry,
)
return DatabaseMigrationState(
resolved_version=resolved_version,
target_version=effective_target_version,
plan=migration_plan,
)
def plan(self, target_version: Optional[int] = None) -> MigrationPlan:
"""生成当前数据库的迁移计划。
Args:
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
Returns:
MigrationPlan: 当前数据库对应的迁移计划。
"""
return self.describe_state(target_version=target_version).plan
def migrate(self, target_version: Optional[int] = None) -> MigrationPlan:
"""执行迁移计划。
注意:
若当前数据库是通过结构探测得出的版本,且计划为空,本方法不会自动把该
版本写回 ``user_version``。这样做是为了避免在尚未明确接入策略前引入隐式
副作用。
Args:
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
Returns:
MigrationPlan: 已执行的迁移计划。
Raises:
DatabaseMigrationExecutionError: 当迁移步骤执行失败时抛出。
"""
migration_state = self.describe_state(target_version=target_version)
migration_plan = migration_state.plan
if migration_plan.is_empty():
logger.info("数据库迁移计划为空,跳过执行。")
return migration_plan
current_version = migration_state.resolved_version.version
total_steps = migration_plan.step_count()
for step_index, step in enumerate(migration_plan.steps, start=1):
logger.info(
f"开始执行数据库迁移步骤: {step.name} ({step.version_from} -> {step.version_to})"
)
try:
with self.progress_reporter_factory() as progress_reporter:
if step.transactional:
with self.engine.begin() as connection:
execution_context = self._build_execution_context(
connection=connection,
current_version=current_version,
migration_plan=migration_plan,
step_index=step_index,
step_name=step.name,
total_steps=total_steps,
progress_reporter=progress_reporter,
)
step.run(execution_context)
self.version_store.write_version(connection, step.version_to)
else:
with self.engine.connect() as connection:
execution_context = self._build_execution_context(
connection=connection,
current_version=current_version,
migration_plan=migration_plan,
step_index=step_index,
step_name=step.name,
total_steps=total_steps,
progress_reporter=progress_reporter,
)
step.run(execution_context)
self.version_store.write_version(connection, step.version_to)
connection.commit()
except Exception as exc:
raise DatabaseMigrationExecutionError(
f"执行迁移步骤 {step.name} ({step.version_from} -> {step.version_to}) 失败。"
) from exc
current_version = step.version_to
logger.info(f"数据库迁移步骤执行完成: {step.name},当前版本已更新为 {current_version}")
return migration_plan
def _resolve_target_version(self, target_version: Optional[int]) -> int:
"""解析最终目标版本号。
Args:
target_version: 调用方显式指定的目标版本。
Returns:
int: 最终用于规划和执行的目标版本号。
"""
if target_version is not None:
return target_version
return self.registry.latest_version()
def _build_execution_context(
self,
connection: Connection,
current_version: int,
migration_plan: MigrationPlan,
step_index: int,
step_name: str,
total_steps: int,
progress_reporter: BaseMigrationProgressReporter,
) -> MigrationExecutionContext:
"""构建单个迁移步骤的执行上下文。
Args:
connection: 当前迁移步骤使用的数据库连接。
current_version: 当前数据库版本。
migration_plan: 当前迁移计划。
step_index: 当前步骤序号,从 ``1`` 开始。
step_name: 当前步骤名称。
total_steps: 计划总步骤数。
progress_reporter: 当前步骤使用的进度上报器。
Returns:
MigrationExecutionContext: 当前步骤的执行上下文对象。
"""
return MigrationExecutionContext(
connection=connection,
current_version=current_version,
target_version=migration_plan.target_version,
step_index=step_index,
step_name=step_name,
total_steps=total_steps,
progress_reporter=progress_reporter,
)

View File

@@ -0,0 +1,285 @@
"""数据库迁移基础设施核心数据模型。"""
from __future__ import annotations
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
from sqlalchemy.engine import Connection
if TYPE_CHECKING:
from .progress import BaseMigrationProgressReporter
def _utc_now() -> datetime:
"""返回当前 UTC 时间。
Returns:
datetime: 当前 UTC 时间。
"""
return datetime.now(timezone.utc)
class SchemaVersionSource(str, Enum):
"""数据库版本来源。"""
PRAGMA = "pragma"
DETECTOR = "detector"
EMPTY_DATABASE = "empty_database"
@dataclass(frozen=True)
class ColumnSchema:
"""数据库列结构快照。"""
name: str
declared_type: str
default_value: Optional[str]
is_not_null: bool
primary_key_position: int
@dataclass(frozen=True)
class TableSchema:
"""数据库表结构快照。"""
name: str
columns: Dict[str, ColumnSchema]
def has_column(self, column_name: str) -> bool:
"""判断表中是否存在指定列。
Args:
column_name: 待检查的列名。
Returns:
bool: 若列存在则返回 ``True``,否则返回 ``False``。
"""
return column_name in self.columns
def get_column(self, column_name: str) -> Optional[ColumnSchema]:
"""获取指定列的结构信息。
Args:
column_name: 待获取的列名。
Returns:
Optional[ColumnSchema]: 列存在时返回列结构,否则返回 ``None``。
"""
return self.columns.get(column_name)
def column_names(self) -> List[str]:
"""返回当前表中全部列名。
Returns:
List[str]: 按字母顺序排列的列名列表。
"""
return sorted(self.columns)
@dataclass(frozen=True)
class DatabaseSchemaSnapshot:
"""数据库结构快照。"""
tables: Dict[str, TableSchema]
def is_empty(self) -> bool:
"""判断数据库是否没有任何用户表。
Returns:
bool: 若数据库中没有用户表则返回 ``True``。
"""
return not self.tables
def has_table(self, table_name: str) -> bool:
"""判断数据库是否存在指定表。
Args:
table_name: 待检查的表名。
Returns:
bool: 若表存在则返回 ``True``,否则返回 ``False``。
"""
return table_name in self.tables
def has_column(self, table_name: str, column_name: str) -> bool:
"""判断数据库指定表中是否存在指定列。
Args:
table_name: 待检查的表名。
column_name: 待检查的列名。
Returns:
bool: 若表和列均存在则返回 ``True``。
"""
table_schema = self.get_table(table_name)
if table_schema is None:
return False
return table_schema.has_column(column_name)
def get_table(self, table_name: str) -> Optional[TableSchema]:
"""获取指定表的结构信息。
Args:
table_name: 待获取的表名。
Returns:
Optional[TableSchema]: 表存在时返回对应结构,否则返回 ``None``。
"""
return self.tables.get(table_name)
def table_names(self) -> List[str]:
"""返回当前数据库中的全部用户表名。
Returns:
List[str]: 按字母顺序排列的表名列表。
"""
return sorted(self.tables)
@dataclass(frozen=True)
class ResolvedSchemaVersion:
"""解析后的数据库版本信息。"""
version: int
source: SchemaVersionSource
detector_name: Optional[str] = None
snapshot: Optional[DatabaseSchemaSnapshot] = None
@dataclass(frozen=True)
class MigrationExecutionContext:
"""单个迁移步骤的执行上下文。"""
connection: Connection
current_version: int
target_version: int
step_index: int
step_name: str
total_steps: int
started_at: datetime = field(default_factory=_utc_now)
progress_reporter: Optional["BaseMigrationProgressReporter"] = None
def is_last_step(self) -> bool:
"""判断当前步骤是否为最后一步。
Returns:
bool: 若当前步骤已是计划中的最后一步则返回 ``True``。
"""
return self.step_index >= self.total_steps
def start_progress(
self,
total: int,
description: str = "总迁移进度",
unit_name: str = "",
) -> None:
"""启动当前迁移步骤的进度展示。
Args:
total: 当前步骤需要处理的总项目数。
description: 进度描述文本。
unit_name: 进度单位名称。
"""
if self.progress_reporter is None:
return
self.progress_reporter.start(total=total, description=description, unit_name=unit_name)
def advance_progress(self, advance: int = 1, item_name: Optional[str] = None) -> None:
"""推进当前迁移步骤的进度展示。
Args:
advance: 本次推进的步数。
item_name: 当前完成的项目名称。
"""
if self.progress_reporter is None:
return
self.progress_reporter.advance(advance=advance, item_name=item_name)
MigrationHandler = Callable[[MigrationExecutionContext], None]
@dataclass(frozen=True)
class MigrationStep:
"""单个数据库迁移步骤定义。"""
version_from: int
version_to: int
name: str
description: str
handler: MigrationHandler
transactional: bool = True
def __post_init__(self) -> None:
"""校验迁移步骤定义是否合法。
Raises:
ValueError: 当版本号不合法或迁移方向错误时抛出。
"""
if self.version_from < 0:
raise ValueError("迁移起始版本不能小于 0。")
if self.version_to <= self.version_from:
raise ValueError("迁移目标版本必须大于起始版本。")
def run(self, context: MigrationExecutionContext) -> None:
"""执行当前迁移步骤。
Args:
context: 当前迁移步骤的执行上下文。
"""
self.handler(context)
@dataclass(frozen=True)
class MigrationPlan:
"""数据库迁移执行计划。"""
current_version: int
target_version: int
steps: List[MigrationStep]
def is_empty(self) -> bool:
"""判断迁移计划是否为空。
Returns:
bool: 若无需执行任何迁移步骤则返回 ``True``。
"""
return not self.steps
def step_count(self) -> int:
"""返回迁移计划中的步骤数量。
Returns:
int: 当前计划中的迁移步骤数。
"""
return len(self.steps)
def latest_reachable_version(self) -> int:
"""返回该计划执行后的最终版本。
Returns:
int: 若计划为空则返回当前版本,否则返回最后一步的目标版本。
"""
if self.is_empty():
return self.current_version
return self.steps[-1].version_to
@dataclass(frozen=True)
class DatabaseMigrationState:
"""数据库迁移状态描述。"""
resolved_version: ResolvedSchemaVersion
target_version: int
plan: MigrationPlan
def requires_migration(self) -> bool:
"""判断当前状态是否需要执行迁移。
Returns:
bool: 若计划中存在待执行迁移步骤则返回 ``True``。
"""
return not self.plan.is_empty()

View File

@@ -0,0 +1,108 @@
"""数据库迁移计划生成器。"""
from typing import List
from .exceptions import (
DatabaseMigrationPlanningError,
MissingMigrationStepError,
UnsupportedMigrationDirectionError,
)
from .models import MigrationPlan, MigrationStep
from .registry import MigrationRegistry
class MigrationPlanner:
"""数据库迁移计划生成器。"""
def plan(
self,
current_version: int,
target_version: int,
registry: MigrationRegistry,
) -> MigrationPlan:
"""根据当前版本与目标版本生成迁移计划。
Args:
current_version: 当前数据库版本。
target_version: 目标数据库版本。
registry: 迁移步骤注册表。
Returns:
MigrationPlan: 按顺序执行的迁移计划。
Raises:
DatabaseMigrationPlanningError: 当版本号非法时抛出。
MissingMigrationStepError: 当所需迁移步骤缺失时抛出。
UnsupportedMigrationDirectionError: 当请求降级迁移时抛出。
"""
self._validate_version(current_version, "current_version")
self._validate_version(target_version, "target_version")
if target_version < current_version:
raise UnsupportedMigrationDirectionError(
f"当前仅支持升级迁移,不支持从 {current_version} 降级到 {target_version}"
)
if target_version == current_version:
return MigrationPlan(current_version=current_version, target_version=target_version, steps=[])
steps = self._build_steps(current_version, target_version, registry)
return MigrationPlan(current_version=current_version, target_version=target_version, steps=steps)
def plan_to_latest(self, current_version: int, registry: MigrationRegistry) -> MigrationPlan:
"""生成迁移到注册表最新版本的执行计划。
Args:
current_version: 当前数据库版本。
registry: 迁移步骤注册表。
Returns:
MigrationPlan: 指向最新版本的迁移计划。
"""
target_version = registry.latest_version()
return self.plan(current_version=current_version, target_version=target_version, registry=registry)
def _build_steps(
self,
current_version: int,
target_version: int,
registry: MigrationRegistry,
) -> List[MigrationStep]:
"""按顺序拼装迁移步骤链。
Args:
current_version: 当前数据库版本。
target_version: 目标数据库版本。
registry: 迁移步骤注册表。
Returns:
List[MigrationStep]: 按顺序执行的迁移步骤列表。
Raises:
MissingMigrationStepError: 当中间某一版本缺少迁移步骤时抛出。
"""
planned_steps: List[MigrationStep] = []
next_version = current_version
while next_version < target_version:
step = registry.get_step(next_version)
if step is None:
raise MissingMigrationStepError(
f"缺少从版本 {next_version} 升级到版本 {next_version + 1} 的迁移步骤。"
)
planned_steps.append(step)
next_version = step.version_to
return planned_steps
def _validate_version(self, version: int, field_name: str) -> None:
"""校验版本号是否合法。
Args:
version: 待校验的版本号。
field_name: 当前版本号对应的字段名。
Raises:
DatabaseMigrationPlanningError: 当版本号非法时抛出。
"""
if version < 0:
raise DatabaseMigrationPlanningError(f"{field_name} 不能小于 0: {version}")

View File

@@ -0,0 +1,272 @@
"""数据库迁移进度展示工具。"""
from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import timedelta
from typing import Optional
from rich.console import Console
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID
from rich.text import Text
def _format_duration(total_seconds: Optional[float]) -> str:
"""将秒数格式化为适合展示的耗时文本。
Args:
total_seconds: 总秒数;为空时表示暂不可用。
Returns:
str: 格式化后的耗时文本。
"""
if total_seconds is None:
return "--:--:--"
safe_seconds = max(total_seconds, 0.0)
return str(timedelta(seconds=int(safe_seconds)))
class MigrationSummaryColumn(ProgressColumn):
"""渲染数据库迁移总进度摘要列。"""
def render(self, task: Task) -> Text:
"""渲染当前任务的总进度摘要。
Args:
task: 当前进度任务对象。
Returns:
Text: 渲染后的摘要文本。
"""
display_total = task.fields.get("display_total", task.total)
total_text = "?" if display_total is None else str(int(display_total))
completed_text = str(int(task.completed))
return Text(f"总迁移进度({completed_text}/{total_text}")
class MigrationSpeedColumn(ProgressColumn):
"""渲染数据库迁移速度列。"""
def render(self, task: Task) -> Text:
"""渲染当前任务的速度信息。
Args:
task: 当前进度任务对象。
Returns:
Text: 渲染后的速度文本。
"""
unit_name = str(task.fields.get("unit_name", ""))
if task.speed is None or task.speed <= 0:
return Text(f"-- {unit_name}/s")
return Text(f"{task.speed:.2f} {unit_name}/s")
class MigrationElapsedColumn(ProgressColumn):
"""渲染数据库迁移已用时间列。"""
def render(self, task: Task) -> Text:
"""渲染当前任务的已用时间。
Args:
task: 当前进度任务对象。
Returns:
Text: 渲染后的已用时间文本。
"""
return Text(f"已用时间 {_format_duration(task.elapsed)}")
class MigrationRemainingColumn(ProgressColumn):
"""渲染数据库迁移预估剩余时间列。"""
def render(self, task: Task) -> Text:
"""渲染当前任务的预估剩余时间。
Args:
task: 当前进度任务对象。
Returns:
Text: 渲染后的预估剩余时间文本。
"""
return Text(f"预估时间 {_format_duration(task.time_remaining)}")
class BaseMigrationProgressReporter(ABC):
"""数据库迁移进度上报器基类。"""
def __enter__(self) -> "BaseMigrationProgressReporter":
"""进入进度上报上下文。
Returns:
BaseMigrationProgressReporter: 当前上报器实例。
"""
self.open()
return self
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""退出进度上报上下文。
Args:
exc_type: 异常类型。
exc_value: 异常实例。
traceback: 异常追踪对象。
"""
del exc_type, exc_value, traceback
self.close()
@abstractmethod
def open(self) -> None:
"""打开进度上报资源。"""
@abstractmethod
def close(self) -> None:
"""关闭进度上报资源。"""
@abstractmethod
def start(
self,
total: int,
description: str = "总迁移进度",
unit_name: str = "",
) -> None:
"""启动一个新的迁移进度任务。
Args:
total: 任务总数。
description: 任务描述。
unit_name: 进度单位名称。
"""
@abstractmethod
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
"""推进当前迁移进度任务。
Args:
advance: 本次推进的步数。
item_name: 当前完成的项目名称。
"""
class NullMigrationProgressReporter(BaseMigrationProgressReporter):
"""不输出任何内容的空进度上报器。"""
def close(self) -> None:
"""关闭空进度上报器。"""
def start(
self,
total: int,
description: str = "总迁移进度",
unit_name: str = "",
) -> None:
"""启动空进度任务。
Args:
total: 任务总数。
description: 任务描述。
unit_name: 进度单位名称。
"""
del total, description, unit_name
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
"""推进空进度任务。
Args:
advance: 本次推进的步数。
item_name: 当前完成的项目名称。
"""
del advance, item_name
class RichMigrationProgressReporter(BaseMigrationProgressReporter):
"""基于 ``rich`` 的数据库迁移进度上报器。"""
def __init__(
self,
console: Optional[Console] = None,
disable: Optional[bool] = None,
refresh_per_second: int = 10,
) -> None:
"""初始化 ``rich`` 迁移进度上报器。
Args:
console: 输出使用的 ``rich`` 控制台。
disable: 是否禁用进度条;为空时根据终端能力自动判断。
refresh_per_second: 每秒刷新次数。
"""
self.console = console or Console()
self.disable = disable
self.refresh_per_second = refresh_per_second
self._progress: Optional[Progress] = None
self._task_id: Optional[TaskID] = None
def open(self) -> None:
"""打开 ``rich`` 进度条资源。"""
effective_disable = not self.console.is_terminal if self.disable is None else self.disable
self._progress = Progress(
MigrationSummaryColumn(),
BarColumn(),
MigrationSpeedColumn(),
MigrationElapsedColumn(),
MigrationRemainingColumn(),
console=self.console,
transient=False,
disable=effective_disable,
refresh_per_second=self.refresh_per_second,
expand=True,
)
self._progress.start()
def close(self) -> None:
"""关闭 ``rich`` 进度条资源。"""
if self._progress is None:
return
self._progress.stop()
self._progress = None
self._task_id = None
def start(
self,
total: int,
description: str = "总迁移进度",
unit_name: str = "",
) -> None:
"""启动一个新的 ``rich`` 迁移进度任务。
Args:
total: 任务总数。
description: 任务描述。
unit_name: 进度单位名称。
"""
if self._progress is None:
self.open()
assert self._progress is not None
effective_total = max(total, 1)
self._task_id = self._progress.add_task(
description,
total=effective_total,
display_total=total,
unit_name=unit_name,
)
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
"""推进当前 ``rich`` 迁移进度任务。
Args:
advance: 本次推进的步数。
item_name: 当前完成的项目名称。
"""
del item_name
if self._progress is None or self._task_id is None:
return
self._progress.update(self._task_id, advance=advance)
def create_rich_migration_progress_reporter() -> BaseMigrationProgressReporter:
"""创建默认的 ``rich`` 迁移进度上报器。
Returns:
BaseMigrationProgressReporter: 默认迁移进度上报器实例。
"""
return RichMigrationProgressReporter()

View File

@@ -0,0 +1,98 @@
"""数据库迁移步骤注册表。"""
from typing import Dict, List, Optional
from .exceptions import DatabaseMigrationConfigurationError
from .models import MigrationStep
class MigrationRegistry:
"""数据库迁移步骤注册表。"""
def __init__(self, steps: Optional[List[MigrationStep]] = None) -> None:
"""初始化迁移步骤注册表。
Args:
steps: 初始化时要注册的迁移步骤列表。
"""
self._steps_by_from_version: Dict[int, MigrationStep] = {}
if steps:
self.register_many(steps)
def register(self, step: MigrationStep) -> None:
"""注册单个迁移步骤。
当前注册表要求每个步骤只负责相邻版本间的升级,以确保迁移链路易于审计、
易于回放,也便于后续生产问题排查。
Args:
step: 待注册的迁移步骤定义。
Raises:
DatabaseMigrationConfigurationError: 当步骤定义冲突或版本跨度不合法时抛出。
"""
if step.version_to != step.version_from + 1:
raise DatabaseMigrationConfigurationError(
"迁移步骤必须使用相邻版本号定义,例如 2 -> 3。"
)
if step.version_from in self._steps_by_from_version:
existing_step = self._steps_by_from_version[step.version_from]
raise DatabaseMigrationConfigurationError(
f"版本 {step.version_from} 已存在迁移步骤: {existing_step.name}"
)
for registered_step in self._steps_by_from_version.values():
if registered_step.version_to == step.version_to:
raise DatabaseMigrationConfigurationError(
f"目标版本 {step.version_to} 已由迁移步骤 {registered_step.name} 占用。"
)
self._steps_by_from_version[step.version_from] = step
def register_many(self, steps: List[MigrationStep]) -> None:
"""批量注册多个迁移步骤。
Args:
steps: 待注册的迁移步骤列表。
"""
for step in steps:
self.register(step)
def get_step(self, version_from: int) -> Optional[MigrationStep]:
"""获取指定起始版本的迁移步骤。
Args:
version_from: 迁移步骤的起始版本号。
Returns:
Optional[MigrationStep]: 若存在对应步骤则返回,否则返回 ``None``。
"""
return self._steps_by_from_version.get(version_from)
def has_step(self, version_from: int) -> bool:
"""判断指定起始版本是否已注册迁移步骤。
Args:
version_from: 待检查的起始版本号。
Returns:
bool: 若已注册对应步骤则返回 ``True``。
"""
return version_from in self._steps_by_from_version
def latest_version(self) -> int:
"""返回当前注册表支持到的最新 schema 版本。
Returns:
int: 若注册表为空则返回 ``0``,否则返回最大目标版本号。
"""
if not self._steps_by_from_version:
return 0
return max(step.version_to for step in self._steps_by_from_version.values())
def list_steps(self) -> List[MigrationStep]:
"""按起始版本顺序返回全部迁移步骤。
Returns:
List[MigrationStep]: 已注册迁移步骤列表。
"""
ordered_versions = sorted(self._steps_by_from_version)
return [self._steps_by_from_version[version] for version in ordered_versions]

View File

@@ -0,0 +1,135 @@
"""数据库版本解析器。"""
from abc import ABC, abstractmethod
from typing import List, Optional
from sqlalchemy.engine import Connection
from .exceptions import DatabaseMigrationVersionError, UnrecognizedDatabaseSchemaError
from .models import DatabaseSchemaSnapshot, ResolvedSchemaVersion, SchemaVersionSource
from .schema import SQLiteSchemaInspector
from .version_store import SQLiteUserVersionStore
class BaseSchemaVersionDetector(ABC):
"""未标记版本数据库的结构探测器基类。"""
@property
@abstractmethod
def name(self) -> str:
"""返回当前探测器名称。
Returns:
str: 当前探测器名称。
"""
@abstractmethod
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
"""根据数据库结构快照推断版本号。
Args:
snapshot: 当前数据库结构快照。
Returns:
Optional[int]: 若识别成功则返回版本号,否则返回 ``None``。
"""
class SchemaVersionResolver:
"""数据库版本解析器。"""
def __init__(
self,
version_store: Optional[SQLiteUserVersionStore] = None,
schema_inspector: Optional[SQLiteSchemaInspector] = None,
detectors: Optional[List[BaseSchemaVersionDetector]] = None,
) -> None:
"""初始化数据库版本解析器。
Args:
version_store: 版本存储器;未提供时将使用默认实现。
schema_inspector: 结构探测器;未提供时将使用默认实现。
detectors: 未标记版本数据库的探测器列表。
"""
self.version_store = version_store or SQLiteUserVersionStore()
self.schema_inspector = schema_inspector or SQLiteSchemaInspector()
self.detectors: List[BaseSchemaVersionDetector] = list(detectors or [])
def add_detector(self, detector: BaseSchemaVersionDetector) -> None:
"""注册一个未标记版本数据库探测器。
Args:
detector: 待注册的探测器实例。
"""
self.detectors.append(detector)
def list_detectors(self) -> List[BaseSchemaVersionDetector]:
"""返回当前已注册的全部探测器。
Returns:
List[BaseSchemaVersionDetector]: 已注册探测器列表副本。
"""
return list(self.detectors)
def resolve(self, connection: Connection) -> ResolvedSchemaVersion:
"""解析当前数据库的 schema 版本信息。
解析顺序如下:
1. 优先读取 ``PRAGMA user_version``。
2. 若其值为 0则对数据库结构做快照。
3. 若数据库为空,则返回空库版本。
4. 若数据库非空,则交给探测器链进行识别。
Args:
connection: 当前数据库连接。
Returns:
ResolvedSchemaVersion: 解析后的数据库版本信息。
Raises:
DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。
UnrecognizedDatabaseSchemaError: 当数据库非空但无法识别版本时抛出。
"""
recorded_version = self.version_store.read_version(connection)
if recorded_version > 0:
return ResolvedSchemaVersion(version=recorded_version, source=SchemaVersionSource.PRAGMA)
snapshot = self.schema_inspector.inspect(connection)
if snapshot.is_empty():
return ResolvedSchemaVersion(
version=0,
source=SchemaVersionSource.EMPTY_DATABASE,
snapshot=snapshot,
)
return self._detect_unversioned_database(snapshot)
def _detect_unversioned_database(self, snapshot: DatabaseSchemaSnapshot) -> ResolvedSchemaVersion:
"""识别未标记版本的历史数据库。
Args:
snapshot: 当前数据库结构快照。
Returns:
ResolvedSchemaVersion: 探测器识别出的版本信息。
Raises:
DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。
UnrecognizedDatabaseSchemaError: 当全部探测器都无法识别结构时抛出。
"""
for detector in self.detectors:
detected_version = detector.detect_version(snapshot)
if detected_version is None:
continue
if detected_version < 0:
raise DatabaseMigrationVersionError(
f"探测器 {detector.name!r} 返回了非法版本号: {detected_version}"
)
return ResolvedSchemaVersion(
version=detected_version,
source=SchemaVersionSource.DETECTOR,
detector_name=detector.name,
snapshot=snapshot,
)
raise UnrecognizedDatabaseSchemaError("当前数据库未记录版本号,且现有探测器无法识别其结构。")

View File

@@ -0,0 +1,98 @@
"""SQLite 数据库结构探测工具。"""
from typing import Dict, List
from sqlalchemy import text
from sqlalchemy.engine import Connection
from .models import ColumnSchema, DatabaseSchemaSnapshot, TableSchema
class SQLiteSchemaInspector:
"""SQLite 数据库结构探测器。"""
def inspect(self, connection: Connection) -> DatabaseSchemaSnapshot:
"""提取数据库中的全部用户表结构快照。
Args:
connection: 当前数据库连接。
Returns:
DatabaseSchemaSnapshot: 当前数据库结构快照。
"""
tables: Dict[str, TableSchema] = {}
for table_name in self.list_user_tables(connection):
table_schema = self.get_table_schema(connection, table_name)
tables[table_name] = table_schema
return DatabaseSchemaSnapshot(tables=tables)
def list_user_tables(self, connection: Connection) -> List[str]:
"""列出数据库中的全部用户表。
Args:
connection: 当前数据库连接。
Returns:
List[str]: 按字母顺序排列的用户表名列表。
"""
statement = text(
"""
SELECT name
FROM sqlite_master
WHERE type = 'table'
AND name NOT LIKE 'sqlite_%'
ORDER BY name
"""
)
rows = connection.execute(statement).fetchall()
return [str(row[0]) for row in rows]
def get_table_schema(self, connection: Connection, table_name: str) -> TableSchema:
"""获取指定表的结构信息。
Args:
connection: 当前数据库连接。
table_name: 待读取结构的表名。
Returns:
TableSchema: 指定表的结构快照。
"""
quoted_table_name = self._quote_identifier(table_name)
rows = connection.exec_driver_sql(f"PRAGMA table_info({quoted_table_name})").mappings().all()
columns: Dict[str, ColumnSchema] = {}
for row in rows:
column_schema = ColumnSchema(
name=str(row["name"]),
declared_type=str(row["type"] or ""),
default_value=None if row["dflt_value"] is None else str(row["dflt_value"]),
is_not_null=bool(row["notnull"]),
primary_key_position=int(row["pk"]),
)
columns[column_schema.name] = column_schema
return TableSchema(name=table_name, columns=columns)
def table_exists(self, connection: Connection, table_name: str) -> bool:
"""判断数据库中是否存在指定表。
Args:
connection: 当前数据库连接。
table_name: 待检查的表名。
Returns:
bool: 若表存在则返回 ``True``。
"""
return table_name in self.list_user_tables(connection)
def _quote_identifier(self, identifier: str) -> str:
"""为 SQLite 标识符添加安全引号。
Args:
identifier: 待引用的 SQLite 标识符。
Returns:
str: 可直接拼接到 PRAGMA 语句中的安全标识符。
"""
escaped_identifier = identifier.replace('"', '""')
return f'"{escaped_identifier}"'

View File

@@ -0,0 +1,57 @@
"""SQLite 数据库版本存储实现。"""
from sqlalchemy.engine import Connection
from .exceptions import DatabaseMigrationVersionError
class SQLiteUserVersionStore:
"""基于 ``PRAGMA user_version`` 的 SQLite 版本存储器。"""
def read_version(self, connection: Connection) -> int:
"""读取当前数据库的 schema 版本号。
Args:
connection: 当前数据库连接。
Returns:
int: 数据库记录的 schema 版本号。
Raises:
DatabaseMigrationVersionError: 当读取结果异常或版本号非法时抛出。
"""
row = connection.exec_driver_sql("PRAGMA user_version").first()
if row is None or len(row) == 0:
raise DatabaseMigrationVersionError("读取 SQLite user_version 失败,返回结果为空。")
version = row[0]
if not isinstance(version, int):
raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不是整数: {version!r}")
if version < 0:
raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不能为负数: {version}")
return version
def write_version(self, connection: Connection, version: int) -> None:
"""写入新的 schema 版本号。
Args:
connection: 当前数据库连接。
version: 待写入的 schema 版本号。
Raises:
DatabaseMigrationVersionError: 当版本号非法时抛出。
"""
self._validate_version(version)
connection.exec_driver_sql(f"PRAGMA user_version = {version}")
def _validate_version(self, version: int) -> None:
"""校验版本号是否合法。
Args:
version: 待校验的版本号。
Raises:
DatabaseMigrationVersionError: 当版本号非法时抛出。
"""
if version < 0:
raise DatabaseMigrationVersionError(f"SQLite user_version 不能小于 0: {version}")