feat(database-migrations): implement database migration manager and related components
- Add DatabaseMigrationManager for orchestrating database migrations, including planning and executing migration steps. - Introduce models for migration state, execution context, and migration steps. - Implement MigrationPlanner to generate migration plans based on current and target versions. - Create MigrationRegistry for registering and managing migration steps. - Develop SchemaVersionResolver to determine the current database schema version. - Add SQLiteSchemaInspector for inspecting SQLite database structures. - Implement progress reporting tools using rich for visualizing migration progress. - Introduce SQLiteUserVersionStore for managing schema version storage in SQLite.
This commit is contained in:
912
pytests/common_test/test_database_migration_foundation.py
Normal file
912
pytests/common_test/test_database_migration_foundation.py
Normal file
@@ -0,0 +1,912 @@
|
||||
"""数据库迁移基础设施测试。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlmodel import SQLModel, create_engine
|
||||
|
||||
import json
|
||||
import msgpack
|
||||
import pytest
|
||||
|
||||
from src.common.database import database as database_module
|
||||
from src.common.database.migrations import (
|
||||
BaseSchemaVersionDetector,
|
||||
BaseMigrationProgressReporter,
|
||||
DatabaseSchemaSnapshot,
|
||||
DatabaseMigrationBootstrapper,
|
||||
DatabaseMigrationState,
|
||||
DatabaseMigrationManager,
|
||||
EMPTY_SCHEMA_VERSION,
|
||||
LATEST_SCHEMA_VERSION,
|
||||
LEGACY_V1_SCHEMA_VERSION,
|
||||
MigrationExecutionContext,
|
||||
MigrationPlan,
|
||||
MigrationRegistry,
|
||||
MigrationStep,
|
||||
ResolvedSchemaVersion,
|
||||
SchemaVersionResolver,
|
||||
SchemaVersionSource,
|
||||
SQLiteSchemaInspector,
|
||||
SQLiteUserVersionStore,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
create_database_migration_bootstrapper,
|
||||
)
|
||||
|
||||
|
||||
class FixedVersionDetector(BaseSchemaVersionDetector):
|
||||
"""测试用固定版本探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回测试探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 探测器名称。
|
||||
"""
|
||||
return "fixed_version_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""根据测试表是否存在返回固定版本。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。
|
||||
"""
|
||||
if snapshot.has_table("legacy_records"):
|
||||
return 2
|
||||
return None
|
||||
|
||||
|
||||
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||
"""测试用迁移进度上报器。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""初始化测试用进度上报器。"""
|
||||
self.events: List[Tuple[str, Optional[int], Optional[str], Optional[str]]] = []
|
||||
|
||||
def open(self) -> None:
|
||||
"""记录打开事件。"""
|
||||
self.events.append(("open", None, None, None))
|
||||
|
||||
def close(self) -> None:
|
||||
"""记录关闭事件。"""
|
||||
self.events.append(("close", None, None, None))
|
||||
|
||||
def start(
|
||||
self,
|
||||
total: int,
|
||||
description: str = "总迁移进度",
|
||||
unit_name: str = "表",
|
||||
) -> None:
|
||||
"""记录启动事件。
|
||||
|
||||
Args:
|
||||
total: 任务总数。
|
||||
description: 任务描述。
|
||||
unit_name: 进度单位名称。
|
||||
"""
|
||||
self.events.append(("start", total, description, unit_name))
|
||||
|
||||
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||
"""记录推进事件。
|
||||
|
||||
Args:
|
||||
advance: 推进步数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
self.events.append(("advance", advance, item_name, None))
|
||||
|
||||
|
||||
def _create_sqlite_engine(database_file: Path) -> Engine:
|
||||
"""创建测试用 SQLite 引擎。
|
||||
|
||||
Args:
|
||||
database_file: 测试数据库文件路径。
|
||||
|
||||
Returns:
|
||||
Engine: SQLite 引擎实例。
|
||||
"""
|
||||
return create_engine(
|
||||
f"sqlite:///{database_file}",
|
||||
echo=False,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
|
||||
def _create_current_schema(connection: Connection) -> None:
|
||||
"""创建当前最新版本的数据库结构。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
import src.common.database.database_model # noqa: F401
|
||||
|
||||
SQLModel.metadata.create_all(connection)
|
||||
|
||||
|
||||
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
|
||||
"""创建带示例数据的旧版 ``0.x`` 数据库结构。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
"""
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE chat_streams (
|
||||
id INTEGER PRIMARY KEY,
|
||||
stream_id TEXT NOT NULL,
|
||||
create_time REAL NOT NULL,
|
||||
last_active_time REAL NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
user_id TEXT,
|
||||
group_id TEXT,
|
||||
group_name TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE messages (
|
||||
id INTEGER PRIMARY KEY,
|
||||
message_id TEXT NOT NULL,
|
||||
time REAL NOT NULL,
|
||||
chat_id TEXT NOT NULL,
|
||||
chat_info_platform TEXT,
|
||||
user_id TEXT,
|
||||
user_nickname TEXT,
|
||||
chat_info_group_id TEXT,
|
||||
chat_info_group_name TEXT,
|
||||
is_mentioned INTEGER,
|
||||
is_at INTEGER,
|
||||
processed_plain_text TEXT,
|
||||
display_message TEXT,
|
||||
is_emoji INTEGER,
|
||||
is_picid INTEGER,
|
||||
is_command INTEGER,
|
||||
is_notify INTEGER,
|
||||
additional_config TEXT,
|
||||
priority_mode TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE action_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
action_id TEXT NOT NULL,
|
||||
time REAL NOT NULL,
|
||||
action_reasoning TEXT,
|
||||
action_name TEXT NOT NULL,
|
||||
action_data TEXT,
|
||||
action_prompt_display TEXT,
|
||||
chat_id TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE expression (
|
||||
id INTEGER PRIMARY KEY,
|
||||
situation TEXT NOT NULL,
|
||||
style TEXT NOT NULL,
|
||||
content_list TEXT,
|
||||
count INTEGER,
|
||||
last_active_time REAL NOT NULL,
|
||||
chat_id TEXT,
|
||||
create_date REAL,
|
||||
checked INTEGER,
|
||||
rejected INTEGER,
|
||||
modified_by TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE jargon (
|
||||
id INTEGER PRIMARY KEY,
|
||||
content TEXT NOT NULL,
|
||||
raw_content TEXT,
|
||||
meaning TEXT,
|
||||
chat_id TEXT,
|
||||
is_global INTEGER,
|
||||
count INTEGER,
|
||||
is_jargon INTEGER,
|
||||
last_inference_count INTEGER,
|
||||
is_complete INTEGER,
|
||||
inference_with_context TEXT,
|
||||
inference_content_only TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO chat_streams (
|
||||
id,
|
||||
stream_id,
|
||||
create_time,
|
||||
last_active_time,
|
||||
platform,
|
||||
user_id,
|
||||
group_id,
|
||||
group_name
|
||||
) VALUES (
|
||||
1,
|
||||
'session-1',
|
||||
1710000000.0,
|
||||
1710000300.0,
|
||||
'qq',
|
||||
'user-1',
|
||||
'group-1',
|
||||
'测试群'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO messages (
|
||||
id,
|
||||
message_id,
|
||||
time,
|
||||
chat_id,
|
||||
chat_info_platform,
|
||||
user_id,
|
||||
user_nickname,
|
||||
chat_info_group_id,
|
||||
chat_info_group_name,
|
||||
is_mentioned,
|
||||
is_at,
|
||||
processed_plain_text,
|
||||
display_message,
|
||||
is_emoji,
|
||||
is_picid,
|
||||
is_command,
|
||||
is_notify,
|
||||
additional_config,
|
||||
priority_mode
|
||||
) VALUES (
|
||||
1,
|
||||
'msg-1',
|
||||
1710000010.0,
|
||||
'session-1',
|
||||
'qq',
|
||||
'user-1',
|
||||
'测试用户',
|
||||
'group-1',
|
||||
'测试群',
|
||||
1,
|
||||
0,
|
||||
'你好',
|
||||
'你好呀',
|
||||
0,
|
||||
1,
|
||||
0,
|
||||
1,
|
||||
'{"source":"legacy"}',
|
||||
'high'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO action_records (
|
||||
id,
|
||||
action_id,
|
||||
time,
|
||||
action_reasoning,
|
||||
action_name,
|
||||
action_data,
|
||||
action_prompt_display,
|
||||
chat_id
|
||||
) VALUES (
|
||||
1,
|
||||
'action-1',
|
||||
1710000020.0,
|
||||
'需要调用工具',
|
||||
'search',
|
||||
'{"query":"MaiBot"}',
|
||||
'执行搜索',
|
||||
'session-1'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO expression (
|
||||
id,
|
||||
situation,
|
||||
style,
|
||||
content_list,
|
||||
count,
|
||||
last_active_time,
|
||||
chat_id,
|
||||
create_date,
|
||||
checked,
|
||||
rejected,
|
||||
modified_by
|
||||
) VALUES (
|
||||
1,
|
||||
'打招呼',
|
||||
'可爱',
|
||||
'["你好呀","早上好"]',
|
||||
3,
|
||||
1710000030.0,
|
||||
'session-1',
|
||||
1710000040.0,
|
||||
1,
|
||||
0,
|
||||
'ai'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO jargon (
|
||||
id,
|
||||
content,
|
||||
raw_content,
|
||||
meaning,
|
||||
chat_id,
|
||||
is_global,
|
||||
count,
|
||||
is_jargon,
|
||||
last_inference_count,
|
||||
is_complete,
|
||||
inference_with_context,
|
||||
inference_content_only
|
||||
) VALUES (
|
||||
1,
|
||||
'上分',
|
||||
'["上分"]',
|
||||
'提高排名',
|
||||
'session-1',
|
||||
0,
|
||||
5,
|
||||
1,
|
||||
2,
|
||||
1,
|
||||
'{"guess":"context"}',
|
||||
'{"guess":"content"}'
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
|
||||
"""应支持读取与写入 SQLite ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "version_store.db")
|
||||
version_store = SQLiteUserVersionStore()
|
||||
|
||||
with engine.begin() as connection:
|
||||
assert version_store.read_version(connection) == 0
|
||||
version_store.write_version(connection, 7)
|
||||
|
||||
with engine.connect() as connection:
|
||||
assert version_store.read_version(connection) == 7
|
||||
|
||||
|
||||
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
|
||||
"""应能提取 SQLite 数据库的表与列结构。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
|
||||
inspector = SQLiteSchemaInspector()
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TABLE legacy_records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
payload TEXT NOT NULL,
|
||||
created_at TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
with engine.connect() as connection:
|
||||
snapshot = inspector.inspect(connection)
|
||||
|
||||
assert snapshot.has_table("legacy_records")
|
||||
assert snapshot.has_column("legacy_records", "payload")
|
||||
assert not snapshot.has_column("legacy_records", "missing_column")
|
||||
table_schema = snapshot.get_table("legacy_records")
|
||||
|
||||
assert table_schema is not None
|
||||
assert table_schema.column_names() == ["created_at", "id", "payload"]
|
||||
|
||||
|
||||
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
|
||||
"""空数据库应被解析为版本 ``0``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
|
||||
resolver = SchemaVersionResolver()
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == 0
|
||||
assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
|
||||
assert resolved_version.snapshot is not None
|
||||
assert resolved_version.snapshot.is_empty()
|
||||
|
||||
|
||||
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
|
||||
"""未写入 ``user_version`` 的历史库应支持通过探测器识别版本。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
|
||||
resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == 2
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "fixed_version_detector"
|
||||
|
||||
|
||||
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
|
||||
"""迁移编排器应能按顺序执行已注册步骤并更新版本号。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "manager.db")
|
||||
executed_steps: List[str] = []
|
||||
|
||||
def migrate_0_to_1(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 0 -> 1。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
|
||||
context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 1 -> 2。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
|
||||
context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=0,
|
||||
version_to=1,
|
||||
name="create_sample_records",
|
||||
description="创建示例表。",
|
||||
handler=migrate_0_to_1,
|
||||
),
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="add_sample_email",
|
||||
description="为示例表增加邮箱字段。",
|
||||
handler=migrate_1_to_2,
|
||||
),
|
||||
]
|
||||
)
|
||||
manager = DatabaseMigrationManager(engine=engine, registry=registry)
|
||||
|
||||
migration_plan = manager.migrate()
|
||||
|
||||
assert migration_plan.step_count() == 2
|
||||
assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
|
||||
|
||||
with engine.connect() as connection:
|
||||
version_store = SQLiteUserVersionStore()
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
recorded_version = version_store.read_version(connection)
|
||||
|
||||
assert recorded_version == 2
|
||||
assert snapshot.has_table("sample_records")
|
||||
assert snapshot.has_column("sample_records", "email")
|
||||
|
||||
|
||||
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
|
||||
"""迁移编排器应支持通过上下文上报步骤进度。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
|
||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||
|
||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||
"""构建测试用进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||
"""
|
||||
reporter = FakeMigrationProgressReporter()
|
||||
reporter_instances.append(reporter)
|
||||
return reporter
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试迁移步骤 ``1 -> 2`` 的进度上报。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
context.start_progress(total=3, description="总迁移进度", unit_name="表")
|
||||
context.advance_progress(item_name="chat_sessions")
|
||||
context.advance_progress(item_name="mai_messages")
|
||||
context.advance_progress(item_name="tool_records")
|
||||
context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))
|
||||
|
||||
with engine.begin() as connection:
|
||||
SQLiteUserVersionStore().write_version(connection, 1)
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="progress_step",
|
||||
description="测试进度上报。",
|
||||
handler=migrate_1_to_2,
|
||||
)
|
||||
]
|
||||
)
|
||||
manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=registry,
|
||||
progress_reporter_factory=_build_reporter,
|
||||
)
|
||||
|
||||
migration_plan = manager.migrate()
|
||||
|
||||
assert migration_plan.step_count() == 1
|
||||
assert len(reporter_instances) == 1
|
||||
assert reporter_instances[0].events == [
|
||||
("open", None, None, None),
|
||||
("start", 3, "总迁移进度", "表"),
|
||||
("advance", 1, "chat_sessions", None),
|
||||
("advance", 1, "mai_messages", None),
|
||||
("advance", 1, "tool_records", None),
|
||||
("close", None, None, None),
|
||||
]
|
||||
|
||||
|
||||
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
|
||||
"""默认解析器应能识别未写入版本号的最新结构数据库。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
|
||||
resolver = build_default_schema_version_resolver()
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "latest_schema_detector"
|
||||
|
||||
|
||||
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
|
||||
"""默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
|
||||
resolver = build_default_schema_version_resolver()
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
with engine.connect() as connection:
|
||||
resolved_version = resolver.resolve(connection)
|
||||
|
||||
assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
|
||||
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
assert resolved_version.detector_name == "legacy_v1_schema_detector"
|
||||
|
||||
|
||||
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
|
||||
"""已是最新结构但未写版本号的数据库应直接补写 ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
|
||||
"""空库在建表完成后应回写最新 ``user_version``。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
|
||||
assert migration_state.target_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_current_schema(connection)
|
||||
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
|
||||
"""启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
|
||||
execution_marks: List[str] = []
|
||||
|
||||
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||
"""测试桥接器迁移步骤 ``1 -> 2``。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤执行上下文。
|
||||
"""
|
||||
execution_marks.append(f"step={context.step_name},index={context.step_index}")
|
||||
context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))
|
||||
|
||||
with engine.begin() as connection:
|
||||
connection.execute(
|
||||
text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
|
||||
)
|
||||
SQLiteUserVersionStore().write_version(connection, 1)
|
||||
|
||||
registry = MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=1,
|
||||
version_to=2,
|
||||
name="bootstrap_add_email",
|
||||
description="为桥接器测试表增加邮箱字段。",
|
||||
handler=migrate_1_to_2,
|
||||
)
|
||||
]
|
||||
)
|
||||
bootstrapper = DatabaseMigrationBootstrapper(
|
||||
manager=DatabaseMigrationManager(engine=engine, registry=registry),
|
||||
latest_schema_version=2,
|
||||
)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
|
||||
assert migration_state.resolved_version.version == 2
|
||||
assert migration_state.target_version == 2
|
||||
assert execution_marks == ["step=bootstrap_add_email,index=1"]
|
||||
|
||||
with engine.connect() as connection:
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
|
||||
assert recorded_version == 2
|
||||
assert snapshot.has_table("bootstrap_records")
|
||||
assert snapshot.has_column("bootstrap_records", "email")
|
||||
|
||||
|
||||
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
|
||||
"""默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
|
||||
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
migration_state = bootstrapper.prepare_database()
|
||||
bootstrapper.finalize_database(migration_state)
|
||||
|
||||
assert not migration_state.requires_migration()
|
||||
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||
assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
|
||||
|
||||
with engine.connect() as connection:
|
||||
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||
message_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, processed_plain_text, additional_config, raw_content
|
||||
FROM mai_messages
|
||||
WHERE message_id = 'msg-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
action_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, action_name, action_display_prompt
|
||||
FROM action_records
|
||||
WHERE action_id = 'action-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
tool_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, tool_name, tool_display_prompt
|
||||
FROM tool_records
|
||||
WHERE tool_id = 'action-1'
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
expression_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id, content_list, modified_by
|
||||
FROM expressions
|
||||
WHERE id = 1
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
jargon_row = connection.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT session_id_dict, raw_content, inference_with_content_only
|
||||
FROM jargons
|
||||
WHERE id = 1
|
||||
"""
|
||||
)
|
||||
).mappings().one()
|
||||
|
||||
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||
assert snapshot.has_table("__legacy_v1_messages")
|
||||
assert snapshot.has_table("chat_sessions")
|
||||
assert snapshot.has_table("mai_messages")
|
||||
assert snapshot.has_table("tool_records")
|
||||
|
||||
unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
|
||||
additional_config = json.loads(message_row["additional_config"])
|
||||
expression_content_list = json.loads(expression_row["content_list"])
|
||||
jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
|
||||
jargon_raw_content = json.loads(jargon_row["raw_content"])
|
||||
|
||||
assert message_row["session_id"] == "session-1"
|
||||
assert message_row["processed_plain_text"] == "你好"
|
||||
assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
|
||||
assert additional_config == {"priority_mode": "high", "source": "legacy"}
|
||||
assert action_row["session_id"] == "session-1"
|
||||
assert action_row["action_name"] == "search"
|
||||
assert action_row["action_display_prompt"] == "执行搜索"
|
||||
assert tool_row["session_id"] == "session-1"
|
||||
assert tool_row["tool_name"] == "search"
|
||||
assert tool_row["tool_display_prompt"] == "执行搜索"
|
||||
assert expression_row["session_id"] == "session-1"
|
||||
assert expression_row["modified_by"] == "AI"
|
||||
assert expression_content_list == ["你好呀", "早上好"]
|
||||
assert jargon_session_id_dict == {"session-1": 5}
|
||||
assert jargon_raw_content == ["上分"]
|
||||
assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
|
||||
|
||||
|
||||
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
|
||||
"""旧版迁移步骤应按目标表数量推进总进度。"""
|
||||
engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
|
||||
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||
|
||||
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||
"""构建测试用进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||
"""
|
||||
reporter = FakeMigrationProgressReporter()
|
||||
reporter_instances.append(reporter)
|
||||
return reporter
|
||||
|
||||
with engine.begin() as connection:
|
||||
_create_legacy_v1_schema_with_sample_data(connection)
|
||||
|
||||
manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=build_default_migration_registry(),
|
||||
resolver=build_default_schema_version_resolver(),
|
||||
progress_reporter_factory=_build_reporter,
|
||||
)
|
||||
|
||||
migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
|
||||
|
||||
assert migration_plan.step_count() == 1
|
||||
assert len(reporter_instances) == 1
|
||||
reporter_events = reporter_instances[0].events
|
||||
|
||||
assert reporter_events[0] == ("open", None, None, None)
|
||||
assert reporter_events[1] == ("start", 12, "总迁移进度", "表")
|
||||
assert reporter_events[-1] == ("close", None, None, None)
|
||||
assert reporter_events.count(("advance", 1, "chat_sessions", None)) == 1
|
||||
assert reporter_events.count(("advance", 1, "thinking_questions", None)) == 1
|
||||
assert len([event for event in reporter_events if event[0] == "advance"]) == 12
|
||||
|
||||
|
||||
def test_initialize_database_calls_bootstrapper_before_create_all(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""数据库初始化入口应先准备迁移,再建表、补迁移并收尾。"""
|
||||
call_order: List[str] = []
|
||||
|
||||
def _fake_prepare_database() -> DatabaseMigrationState:
|
||||
"""返回测试用迁移状态。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 不包含迁移步骤的测试状态。
|
||||
"""
|
||||
call_order.append("prepare_database")
|
||||
return DatabaseMigrationState(
|
||||
resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
|
||||
target_version=LATEST_SCHEMA_VERSION,
|
||||
plan=MigrationPlan(
|
||||
current_version=EMPTY_SCHEMA_VERSION,
|
||||
target_version=LATEST_SCHEMA_VERSION,
|
||||
steps=[],
|
||||
),
|
||||
)
|
||||
|
||||
def _fake_create_all(bind) -> None:
|
||||
"""记录建表调用。
|
||||
|
||||
Args:
|
||||
bind: 传入的数据库绑定对象。
|
||||
"""
|
||||
del bind
|
||||
call_order.append("create_all")
|
||||
|
||||
def _fake_migrate_action_records() -> None:
|
||||
"""记录轻量补迁移调用。"""
|
||||
call_order.append("migrate_action_records")
|
||||
|
||||
def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
|
||||
"""记录迁移收尾调用。
|
||||
|
||||
Args:
|
||||
migration_state: 当前数据库迁移状态。
|
||||
"""
|
||||
del migration_state
|
||||
call_order.append("finalize_database")
|
||||
|
||||
monkeypatch.setattr(database_module, "_db_initialized", False)
|
||||
monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
|
||||
monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
|
||||
monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
|
||||
monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
|
||||
monkeypatch.setattr(database_module, "_migrate_action_records_to_tool_records", _fake_migrate_action_records)
|
||||
|
||||
database_module.initialize_database()
|
||||
|
||||
assert call_order == [
|
||||
"prepare_database",
|
||||
"create_all",
|
||||
"migrate_action_records",
|
||||
"finalize_database",
|
||||
]
|
||||
@@ -1,18 +1,23 @@
|
||||
from rich.traceback import install
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator, TYPE_CHECKING
|
||||
from typing import ContextManager, Generator, TYPE_CHECKING
|
||||
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import event, text
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
|
||||
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlite3 import Connection as SQLite3Connection
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("database")
|
||||
|
||||
|
||||
# 定义数据库文件路径
|
||||
ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||
@@ -53,6 +58,7 @@ SessionLocal = sessionmaker(
|
||||
bind=engine,
|
||||
class_=Session,
|
||||
)
|
||||
_migration_bootstrapper = create_database_migration_bootstrapper(engine)
|
||||
|
||||
_db_initialized = False
|
||||
|
||||
@@ -93,14 +99,29 @@ def _migrate_action_records_to_tool_records() -> None:
|
||||
|
||||
|
||||
def initialize_database() -> None:
|
||||
"""初始化数据库连接、结构与启动期迁移。
|
||||
|
||||
当前初始化流程遵循以下顺序:
|
||||
1. 确保数据库目录存在;
|
||||
2. 加载 SQLModel 模型定义;
|
||||
3. 执行已注册的启动期迁移;
|
||||
4. 兜底执行 ``create_all`` 确保当前模型定义已建表;
|
||||
5. 执行项目现有的轻量数据补迁移逻辑。
|
||||
"""
|
||||
global _db_initialized
|
||||
if _db_initialized:
|
||||
return
|
||||
_DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||
import src.common.database.database_model # noqa: F401
|
||||
|
||||
migration_state = _migration_bootstrapper.prepare_database()
|
||||
logger.info(
|
||||
"数据库迁移准备完成,"
|
||||
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
||||
)
|
||||
SQLModel.metadata.create_all(engine)
|
||||
_migrate_action_records_to_tool_records()
|
||||
_migration_bootstrapper.finalize_database(migration_state)
|
||||
_db_initialized = True
|
||||
|
||||
|
||||
@@ -150,8 +171,12 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_db_session_manual():
|
||||
"""获取数据库会话的上下文管理器 (手动提交模式)。"""
|
||||
def get_db_session_manual() -> ContextManager[Session]:
|
||||
"""获取数据库会话的上下文管理器 (手动提交模式)。
|
||||
|
||||
Returns:
|
||||
ContextManager[Session]: 手动提交模式的数据库会话上下文管理器。
|
||||
"""
|
||||
return get_db_session(auto_commit=False)
|
||||
|
||||
|
||||
|
||||
79
src/common/database/migrations/__init__.py
Normal file
79
src/common/database/migrations/__init__.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""数据库迁移基础设施导出模块。"""
|
||||
|
||||
from .bootstrap import DatabaseMigrationBootstrapper, create_database_migration_bootstrapper
|
||||
from .builtin import (
|
||||
EMPTY_SCHEMA_VERSION,
|
||||
LATEST_SCHEMA_VERSION,
|
||||
LEGACY_V1_SCHEMA_VERSION,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
)
|
||||
from .exceptions import (
|
||||
DatabaseMigrationConfigurationError,
|
||||
DatabaseMigrationError,
|
||||
DatabaseMigrationExecutionError,
|
||||
DatabaseMigrationPlanningError,
|
||||
DatabaseMigrationVersionError,
|
||||
MissingMigrationStepError,
|
||||
UnrecognizedDatabaseSchemaError,
|
||||
UnsupportedMigrationDirectionError,
|
||||
)
|
||||
from .manager import DatabaseMigrationManager
|
||||
from .models import (
|
||||
ColumnSchema,
|
||||
DatabaseMigrationState,
|
||||
DatabaseSchemaSnapshot,
|
||||
MigrationExecutionContext,
|
||||
MigrationPlan,
|
||||
MigrationStep,
|
||||
ResolvedSchemaVersion,
|
||||
SchemaVersionSource,
|
||||
TableSchema,
|
||||
)
|
||||
from .planner import MigrationPlanner
|
||||
from .progress import (
|
||||
BaseMigrationProgressReporter,
|
||||
RichMigrationProgressReporter,
|
||||
create_rich_migration_progress_reporter,
|
||||
)
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||
from .schema import SQLiteSchemaInspector
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
__all__ = [
|
||||
"BaseSchemaVersionDetector",
|
||||
"BaseMigrationProgressReporter",
|
||||
"build_default_migration_registry",
|
||||
"build_default_schema_version_resolver",
|
||||
"ColumnSchema",
|
||||
"create_database_migration_bootstrapper",
|
||||
"create_rich_migration_progress_reporter",
|
||||
"DatabaseMigrationConfigurationError",
|
||||
"DatabaseMigrationError",
|
||||
"DatabaseMigrationBootstrapper",
|
||||
"DatabaseMigrationExecutionError",
|
||||
"DatabaseMigrationManager",
|
||||
"DatabaseMigrationPlanningError",
|
||||
"DatabaseMigrationState",
|
||||
"DatabaseMigrationVersionError",
|
||||
"DatabaseSchemaSnapshot",
|
||||
"EMPTY_SCHEMA_VERSION",
|
||||
"LATEST_SCHEMA_VERSION",
|
||||
"LEGACY_V1_SCHEMA_VERSION",
|
||||
"MigrationExecutionContext",
|
||||
"MigrationPlan",
|
||||
"MigrationPlanner",
|
||||
"MigrationRegistry",
|
||||
"MigrationStep",
|
||||
"MissingMigrationStepError",
|
||||
"ResolvedSchemaVersion",
|
||||
"RichMigrationProgressReporter",
|
||||
"SchemaVersionResolver",
|
||||
"SchemaVersionSource",
|
||||
"SQLiteSchemaInspector",
|
||||
"SQLiteUserVersionStore",
|
||||
"TableSchema",
|
||||
"UnrecognizedDatabaseSchemaError",
|
||||
"UnsupportedMigrationDirectionError",
|
||||
]
|
||||
171
src/common/database/migrations/bootstrap.py
Normal file
171
src/common/database/migrations/bootstrap.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""数据库迁移启动桥接层。"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .builtin import (
|
||||
LATEST_SCHEMA_VERSION,
|
||||
build_default_migration_registry,
|
||||
build_default_schema_version_resolver,
|
||||
)
|
||||
from .exceptions import DatabaseMigrationExecutionError
|
||||
from .manager import DatabaseMigrationManager
|
||||
from .models import DatabaseMigrationState, MigrationPlan, ResolvedSchemaVersion, SchemaVersionSource
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import SchemaVersionResolver
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
logger = get_logger("database_migration")
|
||||
|
||||
|
||||
class DatabaseMigrationBootstrapper:
|
||||
"""数据库迁移启动桥接器。
|
||||
|
||||
该桥接器负责把数据库迁移基础设施接入现有启动流程,同时保持如下约束:
|
||||
1. 若数据库为空,则直接交给当前模型定义建出最新结构;
|
||||
2. 若数据库版本高于当前代码支持的最新版本,则立即终止启动;
|
||||
3. 若存在待执行迁移步骤,则在正常建表流程之前先执行迁移;
|
||||
4. 若数据库已是最新结构但尚未写入 ``user_version``,则在建表后补写版本号。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: DatabaseMigrationManager,
|
||||
latest_schema_version: int = LATEST_SCHEMA_VERSION,
|
||||
) -> None:
|
||||
"""初始化数据库迁移启动桥接器。
|
||||
|
||||
Args:
|
||||
manager: 数据库迁移编排器。
|
||||
latest_schema_version: 当前代码支持的最新 schema 版本号。
|
||||
"""
|
||||
self.manager = manager
|
||||
self.latest_schema_version = latest_schema_version
|
||||
|
||||
def prepare_database(self) -> DatabaseMigrationState:
|
||||
"""为数据库初始化阶段准备迁移状态。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 迁移准备完成后的数据库状态。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationExecutionError: 当数据库版本高于当前代码支持版本时抛出。
|
||||
"""
|
||||
with self.manager.engine.connect() as connection:
|
||||
resolved_version = self.manager.resolver.resolve(connection)
|
||||
|
||||
if resolved_version.version > self.latest_schema_version:
|
||||
raise DatabaseMigrationExecutionError(
|
||||
"当前数据库版本高于代码内注册的最新迁移版本,已拒绝继续启动。"
|
||||
f" 数据库版本={resolved_version.version},代码支持版本={self.latest_schema_version}"
|
||||
)
|
||||
|
||||
if resolved_version.source == SchemaVersionSource.EMPTY_DATABASE:
|
||||
logger.info(
|
||||
"检测到空数据库,将直接根据当前模型创建最新结构。"
|
||||
f" 目标版本={self.latest_schema_version}"
|
||||
)
|
||||
return self._build_noop_state(
|
||||
current_version=resolved_version.version,
|
||||
target_version=self.latest_schema_version,
|
||||
resolved_state=resolved_version,
|
||||
)
|
||||
|
||||
migration_state = self.manager.describe_state(target_version=self.latest_schema_version)
|
||||
if not migration_state.requires_migration():
|
||||
logger.info(
|
||||
f"数据库 schema 已是目标版本,无需迁移。当前版本={migration_state.resolved_version.version}"
|
||||
)
|
||||
return migration_state
|
||||
|
||||
logger.info(
|
||||
"检测到数据库需要迁移,"
|
||||
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
||||
)
|
||||
self.manager.migrate(target_version=self.latest_schema_version)
|
||||
return self.manager.describe_state(target_version=self.latest_schema_version)
|
||||
|
||||
def finalize_database(self, migration_state: DatabaseMigrationState) -> None:
|
||||
"""在数据库初始化末尾补写最终 schema 版本号。
|
||||
|
||||
该方法主要负责两类场景:
|
||||
1. 空库首次建表完成后,将 ``user_version`` 写入为最新版本;
|
||||
2. 已是最新结构但此前未写入 ``user_version`` 的数据库,补写版本号。
|
||||
|
||||
Args:
|
||||
migration_state: 初始化前解析得到的迁移状态。
|
||||
"""
|
||||
if migration_state.requires_migration():
|
||||
return
|
||||
if migration_state.target_version <= 0:
|
||||
return
|
||||
if migration_state.resolved_version.source == SchemaVersionSource.PRAGMA:
|
||||
return
|
||||
|
||||
with self.manager.engine.begin() as connection:
|
||||
self.manager.version_store.write_version(connection, migration_state.target_version)
|
||||
|
||||
logger.info(
|
||||
"数据库 schema 版本写入完成。"
|
||||
f" 来源={migration_state.resolved_version.source.value},"
|
||||
f" 写入版本={migration_state.target_version}"
|
||||
)
|
||||
|
||||
def _build_noop_state(
|
||||
self,
|
||||
current_version: int,
|
||||
target_version: int,
|
||||
resolved_state: ResolvedSchemaVersion,
|
||||
) -> DatabaseMigrationState:
|
||||
"""构建无迁移动作的数据库状态对象。
|
||||
|
||||
Args:
|
||||
current_version: 当前数据库版本号。
|
||||
target_version: 当前初始化流程期望达到的目标版本号。
|
||||
resolved_state: 已解析的数据库版本状态。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 不包含迁移步骤的状态对象。
|
||||
"""
|
||||
return DatabaseMigrationState(
|
||||
resolved_version=resolved_state,
|
||||
target_version=target_version,
|
||||
plan=MigrationPlan(current_version=current_version, target_version=target_version, steps=[]),
|
||||
)
|
||||
|
||||
|
||||
def create_database_migration_bootstrapper(
|
||||
engine: Engine,
|
||||
registry: Optional[MigrationRegistry] = None,
|
||||
resolver: Optional[SchemaVersionResolver] = None,
|
||||
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||
latest_schema_version: int = LATEST_SCHEMA_VERSION,
|
||||
) -> DatabaseMigrationBootstrapper:
|
||||
"""创建数据库迁移启动桥接器。
|
||||
|
||||
Args:
|
||||
engine: 目标数据库引擎。
|
||||
registry: 迁移步骤注册表;未提供时使用默认注册表。
|
||||
resolver: 数据库版本解析器;未提供时使用默认解析器。
|
||||
version_store: 版本存储器;未提供时使用默认存储器。
|
||||
latest_schema_version: 当前代码支持的最新 schema 版本号。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationBootstrapper: 配置完成的数据库迁移启动桥接器。
|
||||
"""
|
||||
migration_registry = registry or build_default_migration_registry()
|
||||
migration_resolver = resolver or build_default_schema_version_resolver()
|
||||
migration_version_store = version_store or SQLiteUserVersionStore()
|
||||
migration_manager = DatabaseMigrationManager(
|
||||
engine=engine,
|
||||
registry=migration_registry,
|
||||
resolver=migration_resolver,
|
||||
version_store=migration_version_store,
|
||||
)
|
||||
return DatabaseMigrationBootstrapper(
|
||||
manager=migration_manager,
|
||||
latest_schema_version=latest_schema_version,
|
||||
)
|
||||
159
src/common/database/migrations/builtin.py
Normal file
159
src/common/database/migrations/builtin.py
Normal file
@@ -0,0 +1,159 @@
|
||||
"""数据库迁移内置版本与默认注册表。"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2
|
||||
from .models import DatabaseSchemaSnapshot, MigrationStep
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
from .schema import SQLiteSchemaInspector
|
||||
|
||||
EMPTY_SCHEMA_VERSION = 0
|
||||
LEGACY_V1_SCHEMA_VERSION = 1
|
||||
LATEST_SCHEMA_VERSION = 2
|
||||
|
||||
_LEGACY_V1_EXCLUSIVE_TABLES = (
|
||||
"chat_streams",
|
||||
"emoji",
|
||||
"emoji_description_cache",
|
||||
"expression",
|
||||
"group_info",
|
||||
"image_descriptions",
|
||||
"jargon",
|
||||
"messages",
|
||||
"thinking_back",
|
||||
)
|
||||
|
||||
|
||||
class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||
"""当前最新 schema 结构探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
return "latest_schema_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""检测数据库是否已经是当前最新结构。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。
|
||||
"""
|
||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||
return None
|
||||
|
||||
latest_marker_tables = (
|
||||
"mai_messages",
|
||||
"chat_sessions",
|
||||
"expressions",
|
||||
"jargons",
|
||||
"thinking_questions",
|
||||
"tool_records",
|
||||
)
|
||||
if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables):
|
||||
return None
|
||||
if not snapshot.has_column("images", "image_hash"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "full_path"):
|
||||
return None
|
||||
if not snapshot.has_column("images", "image_type"):
|
||||
return None
|
||||
if not snapshot.has_column("action_records", "session_id"):
|
||||
return None
|
||||
if not snapshot.has_column("chat_history", "session_id"):
|
||||
return None
|
||||
if not snapshot.has_column("person_info", "user_nickname"):
|
||||
return None
|
||||
return LATEST_SCHEMA_VERSION
|
||||
|
||||
|
||||
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
||||
"""旧版 ``0.x`` schema 结构探测器。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""返回探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
return "legacy_v1_schema_detector"
|
||||
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""检测数据库是否为旧版 ``0.x`` 结构。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。
|
||||
"""
|
||||
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||
return LEGACY_V1_SCHEMA_VERSION
|
||||
|
||||
legacy_shared_markers = (
|
||||
("action_records", ("chat_id", "time")),
|
||||
("chat_history", ("chat_id", "original_text")),
|
||||
("images", ("emoji_hash", "path", "type")),
|
||||
("llm_usage", ("model_api_provider", "status")),
|
||||
("online_time", ("duration",)),
|
||||
("person_info", ("nickname", "group_nick_name")),
|
||||
)
|
||||
for table_name, required_columns in legacy_shared_markers:
|
||||
if snapshot.has_table(table_name) and all(
|
||||
snapshot.has_column(table_name, column_name) for column_name in required_columns
|
||||
):
|
||||
return LEGACY_V1_SCHEMA_VERSION
|
||||
return None
|
||||
|
||||
|
||||
def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
|
||||
"""构建默认 schema 版本探测器链。
|
||||
|
||||
Returns:
|
||||
List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。
|
||||
"""
|
||||
return [
|
||||
LatestSchemaVersionDetector(),
|
||||
LegacyV1SchemaDetector(),
|
||||
]
|
||||
|
||||
|
||||
def build_default_schema_version_resolver() -> SchemaVersionResolver:
|
||||
"""构建默认 schema 版本解析器。
|
||||
|
||||
Returns:
|
||||
SchemaVersionResolver: 配置完成的 schema 版本解析器。
|
||||
"""
|
||||
return SchemaVersionResolver(
|
||||
version_store=SQLiteUserVersionStore(),
|
||||
schema_inspector=SQLiteSchemaInspector(),
|
||||
detectors=build_default_schema_version_detectors(),
|
||||
)
|
||||
|
||||
|
||||
def build_default_migration_registry() -> MigrationRegistry:
|
||||
"""构建默认迁移步骤注册表。
|
||||
|
||||
Returns:
|
||||
MigrationRegistry: 含默认迁移步骤的注册表实例。
|
||||
"""
|
||||
return MigrationRegistry(
|
||||
steps=[
|
||||
MigrationStep(
|
||||
version_from=LEGACY_V1_SCHEMA_VERSION,
|
||||
version_to=LATEST_SCHEMA_VERSION,
|
||||
name="legacy_v1_to_latest_v2",
|
||||
description="将旧版 0.x 数据库整体迁移到当前最新 schema。",
|
||||
handler=migrate_legacy_v1_to_v2,
|
||||
)
|
||||
]
|
||||
)
|
||||
33
src/common/database/migrations/exceptions.py
Normal file
33
src/common/database/migrations/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
||||
"""数据库迁移基础设施异常定义。"""
|
||||
|
||||
|
||||
class DatabaseMigrationError(Exception):
|
||||
"""数据库迁移基础异常。"""
|
||||
|
||||
|
||||
class DatabaseMigrationConfigurationError(DatabaseMigrationError):
|
||||
"""数据库迁移配置不合法。"""
|
||||
|
||||
|
||||
class DatabaseMigrationPlanningError(DatabaseMigrationError):
|
||||
"""数据库迁移计划生成失败。"""
|
||||
|
||||
|
||||
class DatabaseMigrationExecutionError(DatabaseMigrationError):
|
||||
"""数据库迁移执行失败。"""
|
||||
|
||||
|
||||
class DatabaseMigrationVersionError(DatabaseMigrationError):
|
||||
"""数据库版本读写或校验失败。"""
|
||||
|
||||
|
||||
class MissingMigrationStepError(DatabaseMigrationPlanningError):
|
||||
"""缺少某个版本区间所需的迁移步骤。"""
|
||||
|
||||
|
||||
class UnsupportedMigrationDirectionError(DatabaseMigrationPlanningError):
|
||||
"""当前迁移方向不被支持。"""
|
||||
|
||||
|
||||
class UnrecognizedDatabaseSchemaError(DatabaseMigrationVersionError):
|
||||
"""无法识别未标记版本数据库的结构。"""
|
||||
1384
src/common/database/migrations/legacy_v1_to_v2.py
Normal file
1384
src/common/database/migrations/legacy_v1_to_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
205
src/common/database/migrations/manager.py
Normal file
205
src/common/database/migrations/manager.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""数据库迁移编排器。"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .exceptions import DatabaseMigrationExecutionError
|
||||
from .models import DatabaseMigrationState, MigrationExecutionContext, MigrationPlan
|
||||
from .planner import MigrationPlanner
|
||||
from .progress import BaseMigrationProgressReporter, create_rich_migration_progress_reporter
|
||||
from .registry import MigrationRegistry
|
||||
from .resolver import SchemaVersionResolver
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
logger = get_logger("database_migration")
|
||||
|
||||
|
||||
class DatabaseMigrationManager:
|
||||
"""数据库迁移编排器。
|
||||
|
||||
该类只负责基础设施层面的编排工作,包括:
|
||||
1. 解析当前数据库版本;
|
||||
2. 生成迁移计划;
|
||||
3. 顺序执行已注册迁移步骤;
|
||||
4. 在每一步成功后更新 ``user_version``。
|
||||
|
||||
当前模块不内置任何业务迁移步骤,也不会自动接入项目启动流程。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
registry: Optional[MigrationRegistry] = None,
|
||||
planner: Optional[MigrationPlanner] = None,
|
||||
resolver: Optional[SchemaVersionResolver] = None,
|
||||
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||
progress_reporter_factory: Optional[Callable[[], BaseMigrationProgressReporter]] = None,
|
||||
) -> None:
|
||||
"""初始化数据库迁移编排器。
|
||||
|
||||
Args:
|
||||
engine: 目标数据库引擎。
|
||||
registry: 迁移步骤注册表。
|
||||
planner: 迁移计划生成器。
|
||||
resolver: 数据库版本解析器。
|
||||
version_store: 版本存储器。
|
||||
progress_reporter_factory: 迁移进度上报器工厂。
|
||||
"""
|
||||
self.engine = engine
|
||||
self.registry = registry or MigrationRegistry()
|
||||
self.planner = planner or MigrationPlanner()
|
||||
self.resolver = resolver or SchemaVersionResolver()
|
||||
self.version_store = version_store or SQLiteUserVersionStore()
|
||||
self.progress_reporter_factory = progress_reporter_factory or create_rich_migration_progress_reporter
|
||||
|
||||
def describe_state(self, target_version: Optional[int] = None) -> DatabaseMigrationState:
|
||||
"""描述当前数据库的迁移状态。
|
||||
|
||||
Args:
|
||||
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||
|
||||
Returns:
|
||||
DatabaseMigrationState: 当前数据库迁移状态。
|
||||
"""
|
||||
with self.engine.connect() as connection:
|
||||
resolved_version = self.resolver.resolve(connection)
|
||||
|
||||
effective_target_version = self._resolve_target_version(target_version)
|
||||
migration_plan = self.planner.plan(
|
||||
current_version=resolved_version.version,
|
||||
target_version=effective_target_version,
|
||||
registry=self.registry,
|
||||
)
|
||||
return DatabaseMigrationState(
|
||||
resolved_version=resolved_version,
|
||||
target_version=effective_target_version,
|
||||
plan=migration_plan,
|
||||
)
|
||||
|
||||
def plan(self, target_version: Optional[int] = None) -> MigrationPlan:
|
||||
"""生成当前数据库的迁移计划。
|
||||
|
||||
Args:
|
||||
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||
|
||||
Returns:
|
||||
MigrationPlan: 当前数据库对应的迁移计划。
|
||||
"""
|
||||
return self.describe_state(target_version=target_version).plan
|
||||
|
||||
def migrate(self, target_version: Optional[int] = None) -> MigrationPlan:
|
||||
"""执行迁移计划。
|
||||
|
||||
注意:
|
||||
若当前数据库是通过结构探测得出的版本,且计划为空,本方法不会自动把该
|
||||
版本写回 ``user_version``。这样做是为了避免在尚未明确接入策略前引入隐式
|
||||
副作用。
|
||||
|
||||
Args:
|
||||
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||
|
||||
Returns:
|
||||
MigrationPlan: 已执行的迁移计划。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationExecutionError: 当迁移步骤执行失败时抛出。
|
||||
"""
|
||||
migration_state = self.describe_state(target_version=target_version)
|
||||
migration_plan = migration_state.plan
|
||||
if migration_plan.is_empty():
|
||||
logger.info("数据库迁移计划为空,跳过执行。")
|
||||
return migration_plan
|
||||
|
||||
current_version = migration_state.resolved_version.version
|
||||
total_steps = migration_plan.step_count()
|
||||
for step_index, step in enumerate(migration_plan.steps, start=1):
|
||||
logger.info(
|
||||
f"开始执行数据库迁移步骤: {step.name} ({step.version_from} -> {step.version_to})"
|
||||
)
|
||||
try:
|
||||
with self.progress_reporter_factory() as progress_reporter:
|
||||
if step.transactional:
|
||||
with self.engine.begin() as connection:
|
||||
execution_context = self._build_execution_context(
|
||||
connection=connection,
|
||||
current_version=current_version,
|
||||
migration_plan=migration_plan,
|
||||
step_index=step_index,
|
||||
step_name=step.name,
|
||||
total_steps=total_steps,
|
||||
progress_reporter=progress_reporter,
|
||||
)
|
||||
step.run(execution_context)
|
||||
self.version_store.write_version(connection, step.version_to)
|
||||
else:
|
||||
with self.engine.connect() as connection:
|
||||
execution_context = self._build_execution_context(
|
||||
connection=connection,
|
||||
current_version=current_version,
|
||||
migration_plan=migration_plan,
|
||||
step_index=step_index,
|
||||
step_name=step.name,
|
||||
total_steps=total_steps,
|
||||
progress_reporter=progress_reporter,
|
||||
)
|
||||
step.run(execution_context)
|
||||
self.version_store.write_version(connection, step.version_to)
|
||||
connection.commit()
|
||||
except Exception as exc:
|
||||
raise DatabaseMigrationExecutionError(
|
||||
f"执行迁移步骤 {step.name} ({step.version_from} -> {step.version_to}) 失败。"
|
||||
) from exc
|
||||
current_version = step.version_to
|
||||
logger.info(f"数据库迁移步骤执行完成: {step.name},当前版本已更新为 {current_version}")
|
||||
|
||||
return migration_plan
|
||||
|
||||
def _resolve_target_version(self, target_version: Optional[int]) -> int:
|
||||
"""解析最终目标版本号。
|
||||
|
||||
Args:
|
||||
target_version: 调用方显式指定的目标版本。
|
||||
|
||||
Returns:
|
||||
int: 最终用于规划和执行的目标版本号。
|
||||
"""
|
||||
if target_version is not None:
|
||||
return target_version
|
||||
return self.registry.latest_version()
|
||||
|
||||
def _build_execution_context(
|
||||
self,
|
||||
connection: Connection,
|
||||
current_version: int,
|
||||
migration_plan: MigrationPlan,
|
||||
step_index: int,
|
||||
step_name: str,
|
||||
total_steps: int,
|
||||
progress_reporter: BaseMigrationProgressReporter,
|
||||
) -> MigrationExecutionContext:
|
||||
"""构建单个迁移步骤的执行上下文。
|
||||
|
||||
Args:
|
||||
connection: 当前迁移步骤使用的数据库连接。
|
||||
current_version: 当前数据库版本。
|
||||
migration_plan: 当前迁移计划。
|
||||
step_index: 当前步骤序号,从 ``1`` 开始。
|
||||
step_name: 当前步骤名称。
|
||||
total_steps: 计划总步骤数。
|
||||
progress_reporter: 当前步骤使用的进度上报器。
|
||||
|
||||
Returns:
|
||||
MigrationExecutionContext: 当前步骤的执行上下文对象。
|
||||
"""
|
||||
return MigrationExecutionContext(
|
||||
connection=connection,
|
||||
current_version=current_version,
|
||||
target_version=migration_plan.target_version,
|
||||
step_index=step_index,
|
||||
step_name=step_name,
|
||||
total_steps=total_steps,
|
||||
progress_reporter=progress_reporter,
|
||||
)
|
||||
285
src/common/database/migrations/models.py
Normal file
285
src/common/database/migrations/models.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""数据库迁移基础设施核心数据模型。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .progress import BaseMigrationProgressReporter
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
"""返回当前 UTC 时间。
|
||||
|
||||
Returns:
|
||||
datetime: 当前 UTC 时间。
|
||||
"""
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
class SchemaVersionSource(str, Enum):
|
||||
"""数据库版本来源。"""
|
||||
|
||||
PRAGMA = "pragma"
|
||||
DETECTOR = "detector"
|
||||
EMPTY_DATABASE = "empty_database"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ColumnSchema:
|
||||
"""数据库列结构快照。"""
|
||||
|
||||
name: str
|
||||
declared_type: str
|
||||
default_value: Optional[str]
|
||||
is_not_null: bool
|
||||
primary_key_position: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TableSchema:
|
||||
"""数据库表结构快照。"""
|
||||
|
||||
name: str
|
||||
columns: Dict[str, ColumnSchema]
|
||||
|
||||
def has_column(self, column_name: str) -> bool:
|
||||
"""判断表中是否存在指定列。
|
||||
|
||||
Args:
|
||||
column_name: 待检查的列名。
|
||||
|
||||
Returns:
|
||||
bool: 若列存在则返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
return column_name in self.columns
|
||||
|
||||
def get_column(self, column_name: str) -> Optional[ColumnSchema]:
|
||||
"""获取指定列的结构信息。
|
||||
|
||||
Args:
|
||||
column_name: 待获取的列名。
|
||||
|
||||
Returns:
|
||||
Optional[ColumnSchema]: 列存在时返回列结构,否则返回 ``None``。
|
||||
"""
|
||||
return self.columns.get(column_name)
|
||||
|
||||
def column_names(self) -> List[str]:
|
||||
"""返回当前表中全部列名。
|
||||
|
||||
Returns:
|
||||
List[str]: 按字母顺序排列的列名列表。
|
||||
"""
|
||||
return sorted(self.columns)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatabaseSchemaSnapshot:
|
||||
"""数据库结构快照。"""
|
||||
|
||||
tables: Dict[str, TableSchema]
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""判断数据库是否没有任何用户表。
|
||||
|
||||
Returns:
|
||||
bool: 若数据库中没有用户表则返回 ``True``。
|
||||
"""
|
||||
return not self.tables
|
||||
|
||||
def has_table(self, table_name: str) -> bool:
|
||||
"""判断数据库是否存在指定表。
|
||||
|
||||
Args:
|
||||
table_name: 待检查的表名。
|
||||
|
||||
Returns:
|
||||
bool: 若表存在则返回 ``True``,否则返回 ``False``。
|
||||
"""
|
||||
return table_name in self.tables
|
||||
|
||||
def has_column(self, table_name: str, column_name: str) -> bool:
|
||||
"""判断数据库指定表中是否存在指定列。
|
||||
|
||||
Args:
|
||||
table_name: 待检查的表名。
|
||||
column_name: 待检查的列名。
|
||||
|
||||
Returns:
|
||||
bool: 若表和列均存在则返回 ``True``。
|
||||
"""
|
||||
table_schema = self.get_table(table_name)
|
||||
if table_schema is None:
|
||||
return False
|
||||
return table_schema.has_column(column_name)
|
||||
|
||||
def get_table(self, table_name: str) -> Optional[TableSchema]:
|
||||
"""获取指定表的结构信息。
|
||||
|
||||
Args:
|
||||
table_name: 待获取的表名。
|
||||
|
||||
Returns:
|
||||
Optional[TableSchema]: 表存在时返回对应结构,否则返回 ``None``。
|
||||
"""
|
||||
return self.tables.get(table_name)
|
||||
|
||||
def table_names(self) -> List[str]:
|
||||
"""返回当前数据库中的全部用户表名。
|
||||
|
||||
Returns:
|
||||
List[str]: 按字母顺序排列的表名列表。
|
||||
"""
|
||||
return sorted(self.tables)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ResolvedSchemaVersion:
|
||||
"""解析后的数据库版本信息。"""
|
||||
|
||||
version: int
|
||||
source: SchemaVersionSource
|
||||
detector_name: Optional[str] = None
|
||||
snapshot: Optional[DatabaseSchemaSnapshot] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MigrationExecutionContext:
|
||||
"""单个迁移步骤的执行上下文。"""
|
||||
|
||||
connection: Connection
|
||||
current_version: int
|
||||
target_version: int
|
||||
step_index: int
|
||||
step_name: str
|
||||
total_steps: int
|
||||
started_at: datetime = field(default_factory=_utc_now)
|
||||
progress_reporter: Optional["BaseMigrationProgressReporter"] = None
|
||||
|
||||
def is_last_step(self) -> bool:
|
||||
"""判断当前步骤是否为最后一步。
|
||||
|
||||
Returns:
|
||||
bool: 若当前步骤已是计划中的最后一步则返回 ``True``。
|
||||
"""
|
||||
return self.step_index >= self.total_steps
|
||||
|
||||
def start_progress(
|
||||
self,
|
||||
total: int,
|
||||
description: str = "总迁移进度",
|
||||
unit_name: str = "表",
|
||||
) -> None:
|
||||
"""启动当前迁移步骤的进度展示。
|
||||
|
||||
Args:
|
||||
total: 当前步骤需要处理的总项目数。
|
||||
description: 进度描述文本。
|
||||
unit_name: 进度单位名称。
|
||||
"""
|
||||
if self.progress_reporter is None:
|
||||
return
|
||||
self.progress_reporter.start(total=total, description=description, unit_name=unit_name)
|
||||
|
||||
def advance_progress(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||
"""推进当前迁移步骤的进度展示。
|
||||
|
||||
Args:
|
||||
advance: 本次推进的步数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
if self.progress_reporter is None:
|
||||
return
|
||||
self.progress_reporter.advance(advance=advance, item_name=item_name)
|
||||
|
||||
|
||||
MigrationHandler = Callable[[MigrationExecutionContext], None]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MigrationStep:
|
||||
"""单个数据库迁移步骤定义。"""
|
||||
|
||||
version_from: int
|
||||
version_to: int
|
||||
name: str
|
||||
description: str
|
||||
handler: MigrationHandler
|
||||
transactional: bool = True
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""校验迁移步骤定义是否合法。
|
||||
|
||||
Raises:
|
||||
ValueError: 当版本号不合法或迁移方向错误时抛出。
|
||||
"""
|
||||
if self.version_from < 0:
|
||||
raise ValueError("迁移起始版本不能小于 0。")
|
||||
if self.version_to <= self.version_from:
|
||||
raise ValueError("迁移目标版本必须大于起始版本。")
|
||||
|
||||
def run(self, context: MigrationExecutionContext) -> None:
|
||||
"""执行当前迁移步骤。
|
||||
|
||||
Args:
|
||||
context: 当前迁移步骤的执行上下文。
|
||||
"""
|
||||
self.handler(context)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class MigrationPlan:
|
||||
"""数据库迁移执行计划。"""
|
||||
|
||||
current_version: int
|
||||
target_version: int
|
||||
steps: List[MigrationStep]
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""判断迁移计划是否为空。
|
||||
|
||||
Returns:
|
||||
bool: 若无需执行任何迁移步骤则返回 ``True``。
|
||||
"""
|
||||
return not self.steps
|
||||
|
||||
def step_count(self) -> int:
|
||||
"""返回迁移计划中的步骤数量。
|
||||
|
||||
Returns:
|
||||
int: 当前计划中的迁移步骤数。
|
||||
"""
|
||||
return len(self.steps)
|
||||
|
||||
def latest_reachable_version(self) -> int:
|
||||
"""返回该计划执行后的最终版本。
|
||||
|
||||
Returns:
|
||||
int: 若计划为空则返回当前版本,否则返回最后一步的目标版本。
|
||||
"""
|
||||
if self.is_empty():
|
||||
return self.current_version
|
||||
return self.steps[-1].version_to
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DatabaseMigrationState:
|
||||
"""数据库迁移状态描述。"""
|
||||
|
||||
resolved_version: ResolvedSchemaVersion
|
||||
target_version: int
|
||||
plan: MigrationPlan
|
||||
|
||||
def requires_migration(self) -> bool:
|
||||
"""判断当前状态是否需要执行迁移。
|
||||
|
||||
Returns:
|
||||
bool: 若计划中存在待执行迁移步骤则返回 ``True``。
|
||||
"""
|
||||
return not self.plan.is_empty()
|
||||
108
src/common/database/migrations/planner.py
Normal file
108
src/common/database/migrations/planner.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""数据库迁移计划生成器。"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from .exceptions import (
|
||||
DatabaseMigrationPlanningError,
|
||||
MissingMigrationStepError,
|
||||
UnsupportedMigrationDirectionError,
|
||||
)
|
||||
from .models import MigrationPlan, MigrationStep
|
||||
from .registry import MigrationRegistry
|
||||
|
||||
|
||||
class MigrationPlanner:
|
||||
"""数据库迁移计划生成器。"""
|
||||
|
||||
def plan(
|
||||
self,
|
||||
current_version: int,
|
||||
target_version: int,
|
||||
registry: MigrationRegistry,
|
||||
) -> MigrationPlan:
|
||||
"""根据当前版本与目标版本生成迁移计划。
|
||||
|
||||
Args:
|
||||
current_version: 当前数据库版本。
|
||||
target_version: 目标数据库版本。
|
||||
registry: 迁移步骤注册表。
|
||||
|
||||
Returns:
|
||||
MigrationPlan: 按顺序执行的迁移计划。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationPlanningError: 当版本号非法时抛出。
|
||||
MissingMigrationStepError: 当所需迁移步骤缺失时抛出。
|
||||
UnsupportedMigrationDirectionError: 当请求降级迁移时抛出。
|
||||
"""
|
||||
self._validate_version(current_version, "current_version")
|
||||
self._validate_version(target_version, "target_version")
|
||||
|
||||
if target_version < current_version:
|
||||
raise UnsupportedMigrationDirectionError(
|
||||
f"当前仅支持升级迁移,不支持从 {current_version} 降级到 {target_version}。"
|
||||
)
|
||||
if target_version == current_version:
|
||||
return MigrationPlan(current_version=current_version, target_version=target_version, steps=[])
|
||||
|
||||
steps = self._build_steps(current_version, target_version, registry)
|
||||
return MigrationPlan(current_version=current_version, target_version=target_version, steps=steps)
|
||||
|
||||
def plan_to_latest(self, current_version: int, registry: MigrationRegistry) -> MigrationPlan:
|
||||
"""生成迁移到注册表最新版本的执行计划。
|
||||
|
||||
Args:
|
||||
current_version: 当前数据库版本。
|
||||
registry: 迁移步骤注册表。
|
||||
|
||||
Returns:
|
||||
MigrationPlan: 指向最新版本的迁移计划。
|
||||
"""
|
||||
target_version = registry.latest_version()
|
||||
return self.plan(current_version=current_version, target_version=target_version, registry=registry)
|
||||
|
||||
def _build_steps(
|
||||
self,
|
||||
current_version: int,
|
||||
target_version: int,
|
||||
registry: MigrationRegistry,
|
||||
) -> List[MigrationStep]:
|
||||
"""按顺序拼装迁移步骤链。
|
||||
|
||||
Args:
|
||||
current_version: 当前数据库版本。
|
||||
target_version: 目标数据库版本。
|
||||
registry: 迁移步骤注册表。
|
||||
|
||||
Returns:
|
||||
List[MigrationStep]: 按顺序执行的迁移步骤列表。
|
||||
|
||||
Raises:
|
||||
MissingMigrationStepError: 当中间某一版本缺少迁移步骤时抛出。
|
||||
"""
|
||||
planned_steps: List[MigrationStep] = []
|
||||
next_version = current_version
|
||||
|
||||
while next_version < target_version:
|
||||
step = registry.get_step(next_version)
|
||||
if step is None:
|
||||
raise MissingMigrationStepError(
|
||||
f"缺少从版本 {next_version} 升级到版本 {next_version + 1} 的迁移步骤。"
|
||||
)
|
||||
planned_steps.append(step)
|
||||
next_version = step.version_to
|
||||
|
||||
return planned_steps
|
||||
|
||||
def _validate_version(self, version: int, field_name: str) -> None:
|
||||
"""校验版本号是否合法。
|
||||
|
||||
Args:
|
||||
version: 待校验的版本号。
|
||||
field_name: 当前版本号对应的字段名。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationPlanningError: 当版本号非法时抛出。
|
||||
"""
|
||||
if version < 0:
|
||||
raise DatabaseMigrationPlanningError(f"{field_name} 不能小于 0: {version}")
|
||||
272
src/common/database/migrations/progress.py
Normal file
272
src/common/database/migrations/progress.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""数据库迁移进度展示工具。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import timedelta
|
||||
from typing import Optional
|
||||
|
||||
from rich.console import Console
|
||||
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
def _format_duration(total_seconds: Optional[float]) -> str:
|
||||
"""将秒数格式化为适合展示的耗时文本。
|
||||
|
||||
Args:
|
||||
total_seconds: 总秒数;为空时表示暂不可用。
|
||||
|
||||
Returns:
|
||||
str: 格式化后的耗时文本。
|
||||
"""
|
||||
if total_seconds is None:
|
||||
return "--:--:--"
|
||||
safe_seconds = max(total_seconds, 0.0)
|
||||
return str(timedelta(seconds=int(safe_seconds)))
|
||||
|
||||
|
||||
class MigrationSummaryColumn(ProgressColumn):
|
||||
"""渲染数据库迁移总进度摘要列。"""
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
"""渲染当前任务的总进度摘要。
|
||||
|
||||
Args:
|
||||
task: 当前进度任务对象。
|
||||
|
||||
Returns:
|
||||
Text: 渲染后的摘要文本。
|
||||
"""
|
||||
display_total = task.fields.get("display_total", task.total)
|
||||
total_text = "?" if display_total is None else str(int(display_total))
|
||||
completed_text = str(int(task.completed))
|
||||
return Text(f"总迁移进度({completed_text}/{total_text})")
|
||||
|
||||
|
||||
class MigrationSpeedColumn(ProgressColumn):
|
||||
"""渲染数据库迁移速度列。"""
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
"""渲染当前任务的速度信息。
|
||||
|
||||
Args:
|
||||
task: 当前进度任务对象。
|
||||
|
||||
Returns:
|
||||
Text: 渲染后的速度文本。
|
||||
"""
|
||||
unit_name = str(task.fields.get("unit_name", "项"))
|
||||
if task.speed is None or task.speed <= 0:
|
||||
return Text(f"-- {unit_name}/s")
|
||||
return Text(f"{task.speed:.2f} {unit_name}/s")
|
||||
|
||||
|
||||
class MigrationElapsedColumn(ProgressColumn):
|
||||
"""渲染数据库迁移已用时间列。"""
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
"""渲染当前任务的已用时间。
|
||||
|
||||
Args:
|
||||
task: 当前进度任务对象。
|
||||
|
||||
Returns:
|
||||
Text: 渲染后的已用时间文本。
|
||||
"""
|
||||
return Text(f"已用时间 {_format_duration(task.elapsed)}")
|
||||
|
||||
|
||||
class MigrationRemainingColumn(ProgressColumn):
|
||||
"""渲染数据库迁移预估剩余时间列。"""
|
||||
|
||||
def render(self, task: Task) -> Text:
|
||||
"""渲染当前任务的预估剩余时间。
|
||||
|
||||
Args:
|
||||
task: 当前进度任务对象。
|
||||
|
||||
Returns:
|
||||
Text: 渲染后的预估剩余时间文本。
|
||||
"""
|
||||
return Text(f"预估时间 {_format_duration(task.time_remaining)}")
|
||||
|
||||
|
||||
class BaseMigrationProgressReporter(ABC):
|
||||
"""数据库迁移进度上报器基类。"""
|
||||
|
||||
def __enter__(self) -> "BaseMigrationProgressReporter":
|
||||
"""进入进度上报上下文。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 当前上报器实例。
|
||||
"""
|
||||
self.open()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||
"""退出进度上报上下文。
|
||||
|
||||
Args:
|
||||
exc_type: 异常类型。
|
||||
exc_value: 异常实例。
|
||||
traceback: 异常追踪对象。
|
||||
"""
|
||||
del exc_type, exc_value, traceback
|
||||
self.close()
|
||||
|
||||
@abstractmethod
|
||||
def open(self) -> None:
|
||||
"""打开进度上报资源。"""
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""关闭进度上报资源。"""
|
||||
|
||||
@abstractmethod
|
||||
def start(
|
||||
self,
|
||||
total: int,
|
||||
description: str = "总迁移进度",
|
||||
unit_name: str = "表",
|
||||
) -> None:
|
||||
"""启动一个新的迁移进度任务。
|
||||
|
||||
Args:
|
||||
total: 任务总数。
|
||||
description: 任务描述。
|
||||
unit_name: 进度单位名称。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||
"""推进当前迁移进度任务。
|
||||
|
||||
Args:
|
||||
advance: 本次推进的步数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
|
||||
|
||||
class NullMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||
"""不输出任何内容的空进度上报器。"""
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭空进度上报器。"""
|
||||
|
||||
def start(
|
||||
self,
|
||||
total: int,
|
||||
description: str = "总迁移进度",
|
||||
unit_name: str = "表",
|
||||
) -> None:
|
||||
"""启动空进度任务。
|
||||
|
||||
Args:
|
||||
total: 任务总数。
|
||||
description: 任务描述。
|
||||
unit_name: 进度单位名称。
|
||||
"""
|
||||
del total, description, unit_name
|
||||
|
||||
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||
"""推进空进度任务。
|
||||
|
||||
Args:
|
||||
advance: 本次推进的步数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
del advance, item_name
|
||||
|
||||
|
||||
class RichMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||
"""基于 ``rich`` 的数据库迁移进度上报器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
console: Optional[Console] = None,
|
||||
disable: Optional[bool] = None,
|
||||
refresh_per_second: int = 10,
|
||||
) -> None:
|
||||
"""初始化 ``rich`` 迁移进度上报器。
|
||||
|
||||
Args:
|
||||
console: 输出使用的 ``rich`` 控制台。
|
||||
disable: 是否禁用进度条;为空时根据终端能力自动判断。
|
||||
refresh_per_second: 每秒刷新次数。
|
||||
"""
|
||||
self.console = console or Console()
|
||||
self.disable = disable
|
||||
self.refresh_per_second = refresh_per_second
|
||||
self._progress: Optional[Progress] = None
|
||||
self._task_id: Optional[TaskID] = None
|
||||
|
||||
def open(self) -> None:
|
||||
"""打开 ``rich`` 进度条资源。"""
|
||||
effective_disable = not self.console.is_terminal if self.disable is None else self.disable
|
||||
self._progress = Progress(
|
||||
MigrationSummaryColumn(),
|
||||
BarColumn(),
|
||||
MigrationSpeedColumn(),
|
||||
MigrationElapsedColumn(),
|
||||
MigrationRemainingColumn(),
|
||||
console=self.console,
|
||||
transient=False,
|
||||
disable=effective_disable,
|
||||
refresh_per_second=self.refresh_per_second,
|
||||
expand=True,
|
||||
)
|
||||
self._progress.start()
|
||||
|
||||
def close(self) -> None:
|
||||
"""关闭 ``rich`` 进度条资源。"""
|
||||
if self._progress is None:
|
||||
return
|
||||
self._progress.stop()
|
||||
self._progress = None
|
||||
self._task_id = None
|
||||
|
||||
def start(
|
||||
self,
|
||||
total: int,
|
||||
description: str = "总迁移进度",
|
||||
unit_name: str = "表",
|
||||
) -> None:
|
||||
"""启动一个新的 ``rich`` 迁移进度任务。
|
||||
|
||||
Args:
|
||||
total: 任务总数。
|
||||
description: 任务描述。
|
||||
unit_name: 进度单位名称。
|
||||
"""
|
||||
if self._progress is None:
|
||||
self.open()
|
||||
assert self._progress is not None
|
||||
effective_total = max(total, 1)
|
||||
self._task_id = self._progress.add_task(
|
||||
description,
|
||||
total=effective_total,
|
||||
display_total=total,
|
||||
unit_name=unit_name,
|
||||
)
|
||||
|
||||
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||
"""推进当前 ``rich`` 迁移进度任务。
|
||||
|
||||
Args:
|
||||
advance: 本次推进的步数。
|
||||
item_name: 当前完成的项目名称。
|
||||
"""
|
||||
del item_name
|
||||
if self._progress is None or self._task_id is None:
|
||||
return
|
||||
self._progress.update(self._task_id, advance=advance)
|
||||
|
||||
|
||||
def create_rich_migration_progress_reporter() -> BaseMigrationProgressReporter:
|
||||
"""创建默认的 ``rich`` 迁移进度上报器。
|
||||
|
||||
Returns:
|
||||
BaseMigrationProgressReporter: 默认迁移进度上报器实例。
|
||||
"""
|
||||
return RichMigrationProgressReporter()
|
||||
98
src/common/database/migrations/registry.py
Normal file
98
src/common/database/migrations/registry.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""数据库迁移步骤注册表。"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from .exceptions import DatabaseMigrationConfigurationError
|
||||
from .models import MigrationStep
|
||||
|
||||
|
||||
class MigrationRegistry:
|
||||
"""数据库迁移步骤注册表。"""
|
||||
|
||||
def __init__(self, steps: Optional[List[MigrationStep]] = None) -> None:
|
||||
"""初始化迁移步骤注册表。
|
||||
|
||||
Args:
|
||||
steps: 初始化时要注册的迁移步骤列表。
|
||||
"""
|
||||
self._steps_by_from_version: Dict[int, MigrationStep] = {}
|
||||
if steps:
|
||||
self.register_many(steps)
|
||||
|
||||
def register(self, step: MigrationStep) -> None:
|
||||
"""注册单个迁移步骤。
|
||||
|
||||
当前注册表要求每个步骤只负责相邻版本间的升级,以确保迁移链路易于审计、
|
||||
易于回放,也便于后续生产问题排查。
|
||||
|
||||
Args:
|
||||
step: 待注册的迁移步骤定义。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationConfigurationError: 当步骤定义冲突或版本跨度不合法时抛出。
|
||||
"""
|
||||
if step.version_to != step.version_from + 1:
|
||||
raise DatabaseMigrationConfigurationError(
|
||||
"迁移步骤必须使用相邻版本号定义,例如 2 -> 3。"
|
||||
)
|
||||
if step.version_from in self._steps_by_from_version:
|
||||
existing_step = self._steps_by_from_version[step.version_from]
|
||||
raise DatabaseMigrationConfigurationError(
|
||||
f"版本 {step.version_from} 已存在迁移步骤: {existing_step.name}"
|
||||
)
|
||||
for registered_step in self._steps_by_from_version.values():
|
||||
if registered_step.version_to == step.version_to:
|
||||
raise DatabaseMigrationConfigurationError(
|
||||
f"目标版本 {step.version_to} 已由迁移步骤 {registered_step.name} 占用。"
|
||||
)
|
||||
self._steps_by_from_version[step.version_from] = step
|
||||
|
||||
def register_many(self, steps: List[MigrationStep]) -> None:
|
||||
"""批量注册多个迁移步骤。
|
||||
|
||||
Args:
|
||||
steps: 待注册的迁移步骤列表。
|
||||
"""
|
||||
for step in steps:
|
||||
self.register(step)
|
||||
|
||||
def get_step(self, version_from: int) -> Optional[MigrationStep]:
|
||||
"""获取指定起始版本的迁移步骤。
|
||||
|
||||
Args:
|
||||
version_from: 迁移步骤的起始版本号。
|
||||
|
||||
Returns:
|
||||
Optional[MigrationStep]: 若存在对应步骤则返回,否则返回 ``None``。
|
||||
"""
|
||||
return self._steps_by_from_version.get(version_from)
|
||||
|
||||
def has_step(self, version_from: int) -> bool:
|
||||
"""判断指定起始版本是否已注册迁移步骤。
|
||||
|
||||
Args:
|
||||
version_from: 待检查的起始版本号。
|
||||
|
||||
Returns:
|
||||
bool: 若已注册对应步骤则返回 ``True``。
|
||||
"""
|
||||
return version_from in self._steps_by_from_version
|
||||
|
||||
def latest_version(self) -> int:
|
||||
"""返回当前注册表支持到的最新 schema 版本。
|
||||
|
||||
Returns:
|
||||
int: 若注册表为空则返回 ``0``,否则返回最大目标版本号。
|
||||
"""
|
||||
if not self._steps_by_from_version:
|
||||
return 0
|
||||
return max(step.version_to for step in self._steps_by_from_version.values())
|
||||
|
||||
def list_steps(self) -> List[MigrationStep]:
|
||||
"""按起始版本顺序返回全部迁移步骤。
|
||||
|
||||
Returns:
|
||||
List[MigrationStep]: 已注册迁移步骤列表。
|
||||
"""
|
||||
ordered_versions = sorted(self._steps_by_from_version)
|
||||
return [self._steps_by_from_version[version] for version in ordered_versions]
|
||||
135
src/common/database/migrations/resolver.py
Normal file
135
src/common/database/migrations/resolver.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""数据库版本解析器。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from .exceptions import DatabaseMigrationVersionError, UnrecognizedDatabaseSchemaError
|
||||
from .models import DatabaseSchemaSnapshot, ResolvedSchemaVersion, SchemaVersionSource
|
||||
from .schema import SQLiteSchemaInspector
|
||||
from .version_store import SQLiteUserVersionStore
|
||||
|
||||
|
||||
class BaseSchemaVersionDetector(ABC):
|
||||
"""未标记版本数据库的结构探测器基类。"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""返回当前探测器名称。
|
||||
|
||||
Returns:
|
||||
str: 当前探测器名称。
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||
"""根据数据库结构快照推断版本号。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
Optional[int]: 若识别成功则返回版本号,否则返回 ``None``。
|
||||
"""
|
||||
|
||||
|
||||
class SchemaVersionResolver:
|
||||
"""数据库版本解析器。"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||
schema_inspector: Optional[SQLiteSchemaInspector] = None,
|
||||
detectors: Optional[List[BaseSchemaVersionDetector]] = None,
|
||||
) -> None:
|
||||
"""初始化数据库版本解析器。
|
||||
|
||||
Args:
|
||||
version_store: 版本存储器;未提供时将使用默认实现。
|
||||
schema_inspector: 结构探测器;未提供时将使用默认实现。
|
||||
detectors: 未标记版本数据库的探测器列表。
|
||||
"""
|
||||
self.version_store = version_store or SQLiteUserVersionStore()
|
||||
self.schema_inspector = schema_inspector or SQLiteSchemaInspector()
|
||||
self.detectors: List[BaseSchemaVersionDetector] = list(detectors or [])
|
||||
|
||||
def add_detector(self, detector: BaseSchemaVersionDetector) -> None:
|
||||
"""注册一个未标记版本数据库探测器。
|
||||
|
||||
Args:
|
||||
detector: 待注册的探测器实例。
|
||||
"""
|
||||
self.detectors.append(detector)
|
||||
|
||||
def list_detectors(self) -> List[BaseSchemaVersionDetector]:
|
||||
"""返回当前已注册的全部探测器。
|
||||
|
||||
Returns:
|
||||
List[BaseSchemaVersionDetector]: 已注册探测器列表副本。
|
||||
"""
|
||||
return list(self.detectors)
|
||||
|
||||
def resolve(self, connection: Connection) -> ResolvedSchemaVersion:
|
||||
"""解析当前数据库的 schema 版本信息。
|
||||
|
||||
解析顺序如下:
|
||||
1. 优先读取 ``PRAGMA user_version``。
|
||||
2. 若其值为 0,则对数据库结构做快照。
|
||||
3. 若数据库为空,则返回空库版本。
|
||||
4. 若数据库非空,则交给探测器链进行识别。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
|
||||
Returns:
|
||||
ResolvedSchemaVersion: 解析后的数据库版本信息。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。
|
||||
UnrecognizedDatabaseSchemaError: 当数据库非空但无法识别版本时抛出。
|
||||
"""
|
||||
recorded_version = self.version_store.read_version(connection)
|
||||
if recorded_version > 0:
|
||||
return ResolvedSchemaVersion(version=recorded_version, source=SchemaVersionSource.PRAGMA)
|
||||
|
||||
snapshot = self.schema_inspector.inspect(connection)
|
||||
if snapshot.is_empty():
|
||||
return ResolvedSchemaVersion(
|
||||
version=0,
|
||||
source=SchemaVersionSource.EMPTY_DATABASE,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
return self._detect_unversioned_database(snapshot)
|
||||
|
||||
def _detect_unversioned_database(self, snapshot: DatabaseSchemaSnapshot) -> ResolvedSchemaVersion:
|
||||
"""识别未标记版本的历史数据库。
|
||||
|
||||
Args:
|
||||
snapshot: 当前数据库结构快照。
|
||||
|
||||
Returns:
|
||||
ResolvedSchemaVersion: 探测器识别出的版本信息。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。
|
||||
UnrecognizedDatabaseSchemaError: 当全部探测器都无法识别结构时抛出。
|
||||
"""
|
||||
for detector in self.detectors:
|
||||
detected_version = detector.detect_version(snapshot)
|
||||
if detected_version is None:
|
||||
continue
|
||||
if detected_version < 0:
|
||||
raise DatabaseMigrationVersionError(
|
||||
f"探测器 {detector.name!r} 返回了非法版本号: {detected_version}"
|
||||
)
|
||||
return ResolvedSchemaVersion(
|
||||
version=detected_version,
|
||||
source=SchemaVersionSource.DETECTOR,
|
||||
detector_name=detector.name,
|
||||
snapshot=snapshot,
|
||||
)
|
||||
|
||||
raise UnrecognizedDatabaseSchemaError("当前数据库未记录版本号,且现有探测器无法识别其结构。")
|
||||
98
src/common/database/migrations/schema.py
Normal file
98
src/common/database/migrations/schema.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""SQLite 数据库结构探测工具。"""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from .models import ColumnSchema, DatabaseSchemaSnapshot, TableSchema
|
||||
|
||||
|
||||
class SQLiteSchemaInspector:
|
||||
"""SQLite 数据库结构探测器。"""
|
||||
|
||||
def inspect(self, connection: Connection) -> DatabaseSchemaSnapshot:
|
||||
"""提取数据库中的全部用户表结构快照。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
|
||||
Returns:
|
||||
DatabaseSchemaSnapshot: 当前数据库结构快照。
|
||||
"""
|
||||
tables: Dict[str, TableSchema] = {}
|
||||
for table_name in self.list_user_tables(connection):
|
||||
table_schema = self.get_table_schema(connection, table_name)
|
||||
tables[table_name] = table_schema
|
||||
return DatabaseSchemaSnapshot(tables=tables)
|
||||
|
||||
def list_user_tables(self, connection: Connection) -> List[str]:
|
||||
"""列出数据库中的全部用户表。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
|
||||
Returns:
|
||||
List[str]: 按字母顺序排列的用户表名列表。
|
||||
"""
|
||||
statement = text(
|
||||
"""
|
||||
SELECT name
|
||||
FROM sqlite_master
|
||||
WHERE type = 'table'
|
||||
AND name NOT LIKE 'sqlite_%'
|
||||
ORDER BY name
|
||||
"""
|
||||
)
|
||||
rows = connection.execute(statement).fetchall()
|
||||
return [str(row[0]) for row in rows]
|
||||
|
||||
def get_table_schema(self, connection: Connection, table_name: str) -> TableSchema:
|
||||
"""获取指定表的结构信息。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
table_name: 待读取结构的表名。
|
||||
|
||||
Returns:
|
||||
TableSchema: 指定表的结构快照。
|
||||
"""
|
||||
quoted_table_name = self._quote_identifier(table_name)
|
||||
rows = connection.exec_driver_sql(f"PRAGMA table_info({quoted_table_name})").mappings().all()
|
||||
|
||||
columns: Dict[str, ColumnSchema] = {}
|
||||
for row in rows:
|
||||
column_schema = ColumnSchema(
|
||||
name=str(row["name"]),
|
||||
declared_type=str(row["type"] or ""),
|
||||
default_value=None if row["dflt_value"] is None else str(row["dflt_value"]),
|
||||
is_not_null=bool(row["notnull"]),
|
||||
primary_key_position=int(row["pk"]),
|
||||
)
|
||||
columns[column_schema.name] = column_schema
|
||||
|
||||
return TableSchema(name=table_name, columns=columns)
|
||||
|
||||
def table_exists(self, connection: Connection, table_name: str) -> bool:
|
||||
"""判断数据库中是否存在指定表。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
table_name: 待检查的表名。
|
||||
|
||||
Returns:
|
||||
bool: 若表存在则返回 ``True``。
|
||||
"""
|
||||
return table_name in self.list_user_tables(connection)
|
||||
|
||||
def _quote_identifier(self, identifier: str) -> str:
|
||||
"""为 SQLite 标识符添加安全引号。
|
||||
|
||||
Args:
|
||||
identifier: 待引用的 SQLite 标识符。
|
||||
|
||||
Returns:
|
||||
str: 可直接拼接到 PRAGMA 语句中的安全标识符。
|
||||
"""
|
||||
escaped_identifier = identifier.replace('"', '""')
|
||||
return f'"{escaped_identifier}"'
|
||||
57
src/common/database/migrations/version_store.py
Normal file
57
src/common/database/migrations/version_store.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""SQLite 数据库版本存储实现。"""
|
||||
|
||||
from sqlalchemy.engine import Connection
|
||||
|
||||
from .exceptions import DatabaseMigrationVersionError
|
||||
|
||||
|
||||
class SQLiteUserVersionStore:
|
||||
"""基于 ``PRAGMA user_version`` 的 SQLite 版本存储器。"""
|
||||
|
||||
def read_version(self, connection: Connection) -> int:
|
||||
"""读取当前数据库的 schema 版本号。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
|
||||
Returns:
|
||||
int: 数据库记录的 schema 版本号。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationVersionError: 当读取结果异常或版本号非法时抛出。
|
||||
"""
|
||||
row = connection.exec_driver_sql("PRAGMA user_version").first()
|
||||
if row is None or len(row) == 0:
|
||||
raise DatabaseMigrationVersionError("读取 SQLite user_version 失败,返回结果为空。")
|
||||
|
||||
version = row[0]
|
||||
if not isinstance(version, int):
|
||||
raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不是整数: {version!r}")
|
||||
if version < 0:
|
||||
raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不能为负数: {version}")
|
||||
return version
|
||||
|
||||
def write_version(self, connection: Connection, version: int) -> None:
|
||||
"""写入新的 schema 版本号。
|
||||
|
||||
Args:
|
||||
connection: 当前数据库连接。
|
||||
version: 待写入的 schema 版本号。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationVersionError: 当版本号非法时抛出。
|
||||
"""
|
||||
self._validate_version(version)
|
||||
connection.exec_driver_sql(f"PRAGMA user_version = {version}")
|
||||
|
||||
def _validate_version(self, version: int) -> None:
|
||||
"""校验版本号是否合法。
|
||||
|
||||
Args:
|
||||
version: 待校验的版本号。
|
||||
|
||||
Raises:
|
||||
DatabaseMigrationVersionError: 当版本号非法时抛出。
|
||||
"""
|
||||
if version < 0:
|
||||
raise DatabaseMigrationVersionError(f"SQLite user_version 不能小于 0: {version}")
|
||||
Reference in New Issue
Block a user