feat(database-migrations): implement database migration manager and related components
- Add DatabaseMigrationManager for orchestrating database migrations, including planning and executing migration steps. - Introduce models for migration state, execution context, and migration steps. - Implement MigrationPlanner to generate migration plans based on current and target versions. - Create MigrationRegistry for registering and managing migration steps. - Develop SchemaVersionResolver to determine the current database schema version. - Add SQLiteSchemaInspector for inspecting SQLite database structures. - Implement progress reporting tools using rich for visualizing migration progress. - Introduce SQLiteUserVersionStore for managing schema version storage in SQLite.
This commit is contained in:
912
pytests/common_test/test_database_migration_foundation.py
Normal file
912
pytests/common_test/test_database_migration_foundation.py
Normal file
@@ -0,0 +1,912 @@
|
|||||||
|
"""数据库迁移基础设施测试。"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import Connection, Engine
|
||||||
|
from sqlmodel import SQLModel, create_engine
|
||||||
|
|
||||||
|
import json
|
||||||
|
import msgpack
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.common.database import database as database_module
|
||||||
|
from src.common.database.migrations import (
|
||||||
|
BaseSchemaVersionDetector,
|
||||||
|
BaseMigrationProgressReporter,
|
||||||
|
DatabaseSchemaSnapshot,
|
||||||
|
DatabaseMigrationBootstrapper,
|
||||||
|
DatabaseMigrationState,
|
||||||
|
DatabaseMigrationManager,
|
||||||
|
EMPTY_SCHEMA_VERSION,
|
||||||
|
LATEST_SCHEMA_VERSION,
|
||||||
|
LEGACY_V1_SCHEMA_VERSION,
|
||||||
|
MigrationExecutionContext,
|
||||||
|
MigrationPlan,
|
||||||
|
MigrationRegistry,
|
||||||
|
MigrationStep,
|
||||||
|
ResolvedSchemaVersion,
|
||||||
|
SchemaVersionResolver,
|
||||||
|
SchemaVersionSource,
|
||||||
|
SQLiteSchemaInspector,
|
||||||
|
SQLiteUserVersionStore,
|
||||||
|
build_default_migration_registry,
|
||||||
|
build_default_schema_version_resolver,
|
||||||
|
create_database_migration_bootstrapper,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FixedVersionDetector(BaseSchemaVersionDetector):
|
||||||
|
"""测试用固定版本探测器。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""返回测试探测器名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 探测器名称。
|
||||||
|
"""
|
||||||
|
return "fixed_version_detector"
|
||||||
|
|
||||||
|
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||||
|
"""根据测试表是否存在返回固定版本。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot: 当前数据库结构快照。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
if snapshot.has_table("legacy_records"):
|
||||||
|
return 2
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||||
|
"""测试用迁移进度上报器。"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
"""初始化测试用进度上报器。"""
|
||||||
|
self.events: List[Tuple[str, Optional[int], Optional[str], Optional[str]]] = []
|
||||||
|
|
||||||
|
def open(self) -> None:
|
||||||
|
"""记录打开事件。"""
|
||||||
|
self.events.append(("open", None, None, None))
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""记录关闭事件。"""
|
||||||
|
self.events.append(("close", None, None, None))
|
||||||
|
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
total: int,
|
||||||
|
description: str = "总迁移进度",
|
||||||
|
unit_name: str = "表",
|
||||||
|
) -> None:
|
||||||
|
"""记录启动事件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total: 任务总数。
|
||||||
|
description: 任务描述。
|
||||||
|
unit_name: 进度单位名称。
|
||||||
|
"""
|
||||||
|
self.events.append(("start", total, description, unit_name))
|
||||||
|
|
||||||
|
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||||
|
"""记录推进事件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
advance: 推进步数。
|
||||||
|
item_name: 当前完成的项目名称。
|
||||||
|
"""
|
||||||
|
self.events.append(("advance", advance, item_name, None))
|
||||||
|
|
||||||
|
|
||||||
|
def _create_sqlite_engine(database_file: Path) -> Engine:
|
||||||
|
"""创建测试用 SQLite 引擎。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
database_file: 测试数据库文件路径。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Engine: SQLite 引擎实例。
|
||||||
|
"""
|
||||||
|
return create_engine(
|
||||||
|
f"sqlite:///{database_file}",
|
||||||
|
echo=False,
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_current_schema(connection: Connection) -> None:
|
||||||
|
"""创建当前最新版本的数据库结构。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
"""
|
||||||
|
import src.common.database.database_model # noqa: F401
|
||||||
|
|
||||||
|
SQLModel.metadata.create_all(connection)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
|
||||||
|
"""创建带示例数据的旧版 ``0.x`` 数据库结构。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
"""
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
CREATE TABLE chat_streams (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
stream_id TEXT NOT NULL,
|
||||||
|
create_time REAL NOT NULL,
|
||||||
|
last_active_time REAL NOT NULL,
|
||||||
|
platform TEXT NOT NULL,
|
||||||
|
user_id TEXT,
|
||||||
|
group_id TEXT,
|
||||||
|
group_name TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
CREATE TABLE messages (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
message_id TEXT NOT NULL,
|
||||||
|
time REAL NOT NULL,
|
||||||
|
chat_id TEXT NOT NULL,
|
||||||
|
chat_info_platform TEXT,
|
||||||
|
user_id TEXT,
|
||||||
|
user_nickname TEXT,
|
||||||
|
chat_info_group_id TEXT,
|
||||||
|
chat_info_group_name TEXT,
|
||||||
|
is_mentioned INTEGER,
|
||||||
|
is_at INTEGER,
|
||||||
|
processed_plain_text TEXT,
|
||||||
|
display_message TEXT,
|
||||||
|
is_emoji INTEGER,
|
||||||
|
is_picid INTEGER,
|
||||||
|
is_command INTEGER,
|
||||||
|
is_notify INTEGER,
|
||||||
|
additional_config TEXT,
|
||||||
|
priority_mode TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
CREATE TABLE action_records (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
action_id TEXT NOT NULL,
|
||||||
|
time REAL NOT NULL,
|
||||||
|
action_reasoning TEXT,
|
||||||
|
action_name TEXT NOT NULL,
|
||||||
|
action_data TEXT,
|
||||||
|
action_prompt_display TEXT,
|
||||||
|
chat_id TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
CREATE TABLE expression (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
situation TEXT NOT NULL,
|
||||||
|
style TEXT NOT NULL,
|
||||||
|
content_list TEXT,
|
||||||
|
count INTEGER,
|
||||||
|
last_active_time REAL NOT NULL,
|
||||||
|
chat_id TEXT,
|
||||||
|
create_date REAL,
|
||||||
|
checked INTEGER,
|
||||||
|
rejected INTEGER,
|
||||||
|
modified_by TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
CREATE TABLE jargon (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
content TEXT NOT NULL,
|
||||||
|
raw_content TEXT,
|
||||||
|
meaning TEXT,
|
||||||
|
chat_id TEXT,
|
||||||
|
is_global INTEGER,
|
||||||
|
count INTEGER,
|
||||||
|
is_jargon INTEGER,
|
||||||
|
last_inference_count INTEGER,
|
||||||
|
is_complete INTEGER,
|
||||||
|
inference_with_context TEXT,
|
||||||
|
inference_content_only TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO chat_streams (
|
||||||
|
id,
|
||||||
|
stream_id,
|
||||||
|
create_time,
|
||||||
|
last_active_time,
|
||||||
|
platform,
|
||||||
|
user_id,
|
||||||
|
group_id,
|
||||||
|
group_name
|
||||||
|
) VALUES (
|
||||||
|
1,
|
||||||
|
'session-1',
|
||||||
|
1710000000.0,
|
||||||
|
1710000300.0,
|
||||||
|
'qq',
|
||||||
|
'user-1',
|
||||||
|
'group-1',
|
||||||
|
'测试群'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO messages (
|
||||||
|
id,
|
||||||
|
message_id,
|
||||||
|
time,
|
||||||
|
chat_id,
|
||||||
|
chat_info_platform,
|
||||||
|
user_id,
|
||||||
|
user_nickname,
|
||||||
|
chat_info_group_id,
|
||||||
|
chat_info_group_name,
|
||||||
|
is_mentioned,
|
||||||
|
is_at,
|
||||||
|
processed_plain_text,
|
||||||
|
display_message,
|
||||||
|
is_emoji,
|
||||||
|
is_picid,
|
||||||
|
is_command,
|
||||||
|
is_notify,
|
||||||
|
additional_config,
|
||||||
|
priority_mode
|
||||||
|
) VALUES (
|
||||||
|
1,
|
||||||
|
'msg-1',
|
||||||
|
1710000010.0,
|
||||||
|
'session-1',
|
||||||
|
'qq',
|
||||||
|
'user-1',
|
||||||
|
'测试用户',
|
||||||
|
'group-1',
|
||||||
|
'测试群',
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
'你好',
|
||||||
|
'你好呀',
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
'{"source":"legacy"}',
|
||||||
|
'high'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO action_records (
|
||||||
|
id,
|
||||||
|
action_id,
|
||||||
|
time,
|
||||||
|
action_reasoning,
|
||||||
|
action_name,
|
||||||
|
action_data,
|
||||||
|
action_prompt_display,
|
||||||
|
chat_id
|
||||||
|
) VALUES (
|
||||||
|
1,
|
||||||
|
'action-1',
|
||||||
|
1710000020.0,
|
||||||
|
'需要调用工具',
|
||||||
|
'search',
|
||||||
|
'{"query":"MaiBot"}',
|
||||||
|
'执行搜索',
|
||||||
|
'session-1'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO expression (
|
||||||
|
id,
|
||||||
|
situation,
|
||||||
|
style,
|
||||||
|
content_list,
|
||||||
|
count,
|
||||||
|
last_active_time,
|
||||||
|
chat_id,
|
||||||
|
create_date,
|
||||||
|
checked,
|
||||||
|
rejected,
|
||||||
|
modified_by
|
||||||
|
) VALUES (
|
||||||
|
1,
|
||||||
|
'打招呼',
|
||||||
|
'可爱',
|
||||||
|
'["你好呀","早上好"]',
|
||||||
|
3,
|
||||||
|
1710000030.0,
|
||||||
|
'session-1',
|
||||||
|
1710000040.0,
|
||||||
|
1,
|
||||||
|
0,
|
||||||
|
'ai'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO jargon (
|
||||||
|
id,
|
||||||
|
content,
|
||||||
|
raw_content,
|
||||||
|
meaning,
|
||||||
|
chat_id,
|
||||||
|
is_global,
|
||||||
|
count,
|
||||||
|
is_jargon,
|
||||||
|
last_inference_count,
|
||||||
|
is_complete,
|
||||||
|
inference_with_context,
|
||||||
|
inference_content_only
|
||||||
|
) VALUES (
|
||||||
|
1,
|
||||||
|
'上分',
|
||||||
|
'["上分"]',
|
||||||
|
'提高排名',
|
||||||
|
'session-1',
|
||||||
|
0,
|
||||||
|
5,
|
||||||
|
1,
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
'{"guess":"context"}',
|
||||||
|
'{"guess":"content"}'
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
|
||||||
|
"""应支持读取与写入 SQLite ``user_version``。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "version_store.db")
|
||||||
|
version_store = SQLiteUserVersionStore()
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
assert version_store.read_version(connection) == 0
|
||||||
|
version_store.write_version(connection, 7)
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
assert version_store.read_version(connection) == 7
|
||||||
|
|
||||||
|
|
||||||
|
def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
|
||||||
|
"""应能提取 SQLite 数据库的表与列结构。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
|
||||||
|
inspector = SQLiteSchemaInspector()
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
CREATE TABLE legacy_records (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
payload TEXT NOT NULL,
|
||||||
|
created_at TEXT
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
snapshot = inspector.inspect(connection)
|
||||||
|
|
||||||
|
assert snapshot.has_table("legacy_records")
|
||||||
|
assert snapshot.has_column("legacy_records", "payload")
|
||||||
|
assert not snapshot.has_column("legacy_records", "missing_column")
|
||||||
|
table_schema = snapshot.get_table("legacy_records")
|
||||||
|
|
||||||
|
assert table_schema is not None
|
||||||
|
assert table_schema.column_names() == ["created_at", "id", "payload"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
|
||||||
|
"""空数据库应被解析为版本 ``0``。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
|
||||||
|
resolver = SchemaVersionResolver()
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
resolved_version = resolver.resolve(connection)
|
||||||
|
|
||||||
|
assert resolved_version.version == 0
|
||||||
|
assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
|
||||||
|
assert resolved_version.snapshot is not None
|
||||||
|
assert resolved_version.snapshot.is_empty()
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
|
||||||
|
"""未写入 ``user_version`` 的历史库应支持通过探测器识别版本。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
|
||||||
|
resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
resolved_version = resolver.resolve(connection)
|
||||||
|
|
||||||
|
assert resolved_version.version == 2
|
||||||
|
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||||
|
assert resolved_version.detector_name == "fixed_version_detector"
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
|
||||||
|
"""迁移编排器应能按顺序执行已注册步骤并更新版本号。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "manager.db")
|
||||||
|
executed_steps: List[str] = []
|
||||||
|
|
||||||
|
def migrate_0_to_1(context: MigrationExecutionContext) -> None:
|
||||||
|
"""测试迁移步骤 0 -> 1。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 当前迁移步骤执行上下文。
|
||||||
|
"""
|
||||||
|
executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
|
||||||
|
context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))
|
||||||
|
|
||||||
|
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||||
|
"""测试迁移步骤 1 -> 2。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 当前迁移步骤执行上下文。
|
||||||
|
"""
|
||||||
|
executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
|
||||||
|
context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))
|
||||||
|
|
||||||
|
registry = MigrationRegistry(
|
||||||
|
steps=[
|
||||||
|
MigrationStep(
|
||||||
|
version_from=0,
|
||||||
|
version_to=1,
|
||||||
|
name="create_sample_records",
|
||||||
|
description="创建示例表。",
|
||||||
|
handler=migrate_0_to_1,
|
||||||
|
),
|
||||||
|
MigrationStep(
|
||||||
|
version_from=1,
|
||||||
|
version_to=2,
|
||||||
|
name="add_sample_email",
|
||||||
|
description="为示例表增加邮箱字段。",
|
||||||
|
handler=migrate_1_to_2,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
manager = DatabaseMigrationManager(engine=engine, registry=registry)
|
||||||
|
|
||||||
|
migration_plan = manager.migrate()
|
||||||
|
|
||||||
|
assert migration_plan.step_count() == 2
|
||||||
|
assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
version_store = SQLiteUserVersionStore()
|
||||||
|
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||||
|
recorded_version = version_store.read_version(connection)
|
||||||
|
|
||||||
|
assert recorded_version == 2
|
||||||
|
assert snapshot.has_table("sample_records")
|
||||||
|
assert snapshot.has_column("sample_records", "email")
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_can_report_step_progress(tmp_path: Path) -> None:
|
||||||
|
"""迁移编排器应支持通过上下文上报步骤进度。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
|
||||||
|
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||||
|
|
||||||
|
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||||
|
"""构建测试用进度上报器。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||||
|
"""
|
||||||
|
reporter = FakeMigrationProgressReporter()
|
||||||
|
reporter_instances.append(reporter)
|
||||||
|
return reporter
|
||||||
|
|
||||||
|
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||||
|
"""测试迁移步骤 ``1 -> 2`` 的进度上报。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 当前迁移步骤执行上下文。
|
||||||
|
"""
|
||||||
|
context.start_progress(total=3, description="总迁移进度", unit_name="表")
|
||||||
|
context.advance_progress(item_name="chat_sessions")
|
||||||
|
context.advance_progress(item_name="mai_messages")
|
||||||
|
context.advance_progress(item_name="tool_records")
|
||||||
|
context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
SQLiteUserVersionStore().write_version(connection, 1)
|
||||||
|
|
||||||
|
registry = MigrationRegistry(
|
||||||
|
steps=[
|
||||||
|
MigrationStep(
|
||||||
|
version_from=1,
|
||||||
|
version_to=2,
|
||||||
|
name="progress_step",
|
||||||
|
description="测试进度上报。",
|
||||||
|
handler=migrate_1_to_2,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
manager = DatabaseMigrationManager(
|
||||||
|
engine=engine,
|
||||||
|
registry=registry,
|
||||||
|
progress_reporter_factory=_build_reporter,
|
||||||
|
)
|
||||||
|
|
||||||
|
migration_plan = manager.migrate()
|
||||||
|
|
||||||
|
assert migration_plan.step_count() == 1
|
||||||
|
assert len(reporter_instances) == 1
|
||||||
|
assert reporter_instances[0].events == [
|
||||||
|
("open", None, None, None),
|
||||||
|
("start", 3, "总迁移进度", "表"),
|
||||||
|
("advance", 1, "chat_sessions", None),
|
||||||
|
("advance", 1, "mai_messages", None),
|
||||||
|
("advance", 1, "tool_records", None),
|
||||||
|
("close", None, None, None),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
|
||||||
|
"""默认解析器应能识别未写入版本号的最新结构数据库。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
|
||||||
|
resolver = build_default_schema_version_resolver()
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
_create_current_schema(connection)
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
resolved_version = resolver.resolve(connection)
|
||||||
|
|
||||||
|
assert resolved_version.version == LATEST_SCHEMA_VERSION
|
||||||
|
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||||
|
assert resolved_version.detector_name == "latest_schema_detector"
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
|
||||||
|
"""默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
|
||||||
|
resolver = build_default_schema_version_resolver()
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
_create_legacy_v1_schema_with_sample_data(connection)
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
resolved_version = resolver.resolve(connection)
|
||||||
|
|
||||||
|
assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
|
||||||
|
assert resolved_version.source == SchemaVersionSource.DETECTOR
|
||||||
|
assert resolved_version.detector_name == "legacy_v1_schema_detector"
|
||||||
|
|
||||||
|
|
||||||
|
def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
|
||||||
|
"""已是最新结构但未写版本号的数据库应直接补写 ``user_version``。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
|
||||||
|
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
_create_current_schema(connection)
|
||||||
|
|
||||||
|
migration_state = bootstrapper.prepare_database()
|
||||||
|
bootstrapper.finalize_database(migration_state)
|
||||||
|
|
||||||
|
assert not migration_state.requires_migration()
|
||||||
|
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||||
|
assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||||
|
|
||||||
|
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||||
|
|
||||||
|
|
||||||
|
def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
|
||||||
|
"""空库在建表完成后应回写最新 ``user_version``。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
|
||||||
|
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||||
|
|
||||||
|
migration_state = bootstrapper.prepare_database()
|
||||||
|
|
||||||
|
assert not migration_state.requires_migration()
|
||||||
|
assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
|
||||||
|
assert migration_state.target_version == LATEST_SCHEMA_VERSION
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
_create_current_schema(connection)
|
||||||
|
|
||||||
|
bootstrapper.finalize_database(migration_state)
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||||
|
|
||||||
|
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||||
|
|
||||||
|
|
||||||
|
def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
|
||||||
|
"""启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
|
||||||
|
execution_marks: List[str] = []
|
||||||
|
|
||||||
|
def migrate_1_to_2(context: MigrationExecutionContext) -> None:
|
||||||
|
"""测试桥接器迁移步骤 ``1 -> 2``。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 当前迁移步骤执行上下文。
|
||||||
|
"""
|
||||||
|
execution_marks.append(f"step={context.step_name},index={context.step_index}")
|
||||||
|
context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
connection.execute(
|
||||||
|
text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
|
||||||
|
)
|
||||||
|
SQLiteUserVersionStore().write_version(connection, 1)
|
||||||
|
|
||||||
|
registry = MigrationRegistry(
|
||||||
|
steps=[
|
||||||
|
MigrationStep(
|
||||||
|
version_from=1,
|
||||||
|
version_to=2,
|
||||||
|
name="bootstrap_add_email",
|
||||||
|
description="为桥接器测试表增加邮箱字段。",
|
||||||
|
handler=migrate_1_to_2,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
bootstrapper = DatabaseMigrationBootstrapper(
|
||||||
|
manager=DatabaseMigrationManager(engine=engine, registry=registry),
|
||||||
|
latest_schema_version=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
migration_state = bootstrapper.prepare_database()
|
||||||
|
|
||||||
|
assert migration_state.resolved_version.version == 2
|
||||||
|
assert migration_state.target_version == 2
|
||||||
|
assert execution_marks == ["step=bootstrap_add_email,index=1"]
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||||
|
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||||
|
|
||||||
|
assert recorded_version == 2
|
||||||
|
assert snapshot.has_table("bootstrap_records")
|
||||||
|
assert snapshot.has_column("bootstrap_records", "email")
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
|
||||||
|
"""默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
|
||||||
|
bootstrapper = create_database_migration_bootstrapper(engine)
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
_create_legacy_v1_schema_with_sample_data(connection)
|
||||||
|
|
||||||
|
migration_state = bootstrapper.prepare_database()
|
||||||
|
bootstrapper.finalize_database(migration_state)
|
||||||
|
|
||||||
|
assert not migration_state.requires_migration()
|
||||||
|
assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
|
||||||
|
assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA
|
||||||
|
|
||||||
|
with engine.connect() as connection:
|
||||||
|
recorded_version = SQLiteUserVersionStore().read_version(connection)
|
||||||
|
snapshot = SQLiteSchemaInspector().inspect(connection)
|
||||||
|
message_row = connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT session_id, processed_plain_text, additional_config, raw_content
|
||||||
|
FROM mai_messages
|
||||||
|
WHERE message_id = 'msg-1'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
).mappings().one()
|
||||||
|
action_row = connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT session_id, action_name, action_display_prompt
|
||||||
|
FROM action_records
|
||||||
|
WHERE action_id = 'action-1'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
).mappings().one()
|
||||||
|
tool_row = connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT session_id, tool_name, tool_display_prompt
|
||||||
|
FROM tool_records
|
||||||
|
WHERE tool_id = 'action-1'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
).mappings().one()
|
||||||
|
expression_row = connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT session_id, content_list, modified_by
|
||||||
|
FROM expressions
|
||||||
|
WHERE id = 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
).mappings().one()
|
||||||
|
jargon_row = connection.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT session_id_dict, raw_content, inference_with_content_only
|
||||||
|
FROM jargons
|
||||||
|
WHERE id = 1
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
).mappings().one()
|
||||||
|
|
||||||
|
assert recorded_version == LATEST_SCHEMA_VERSION
|
||||||
|
assert snapshot.has_table("__legacy_v1_messages")
|
||||||
|
assert snapshot.has_table("chat_sessions")
|
||||||
|
assert snapshot.has_table("mai_messages")
|
||||||
|
assert snapshot.has_table("tool_records")
|
||||||
|
|
||||||
|
unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
|
||||||
|
additional_config = json.loads(message_row["additional_config"])
|
||||||
|
expression_content_list = json.loads(expression_row["content_list"])
|
||||||
|
jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
|
||||||
|
jargon_raw_content = json.loads(jargon_row["raw_content"])
|
||||||
|
|
||||||
|
assert message_row["session_id"] == "session-1"
|
||||||
|
assert message_row["processed_plain_text"] == "你好"
|
||||||
|
assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
|
||||||
|
assert additional_config == {"priority_mode": "high", "source": "legacy"}
|
||||||
|
assert action_row["session_id"] == "session-1"
|
||||||
|
assert action_row["action_name"] == "search"
|
||||||
|
assert action_row["action_display_prompt"] == "执行搜索"
|
||||||
|
assert tool_row["session_id"] == "session-1"
|
||||||
|
assert tool_row["tool_name"] == "search"
|
||||||
|
assert tool_row["tool_display_prompt"] == "执行搜索"
|
||||||
|
assert expression_row["session_id"] == "session-1"
|
||||||
|
assert expression_row["modified_by"] == "AI"
|
||||||
|
assert expression_content_list == ["你好呀", "早上好"]
|
||||||
|
assert jargon_session_id_dict == {"session-1": 5}
|
||||||
|
assert jargon_raw_content == ["上分"]
|
||||||
|
assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'
|
||||||
|
|
||||||
|
|
||||||
|
def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
|
||||||
|
"""旧版迁移步骤应按目标表数量推进总进度。"""
|
||||||
|
engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
|
||||||
|
reporter_instances: List[FakeMigrationProgressReporter] = []
|
||||||
|
|
||||||
|
def _build_reporter() -> BaseMigrationProgressReporter:
|
||||||
|
"""构建测试用进度上报器。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseMigrationProgressReporter: 测试用进度上报器实例。
|
||||||
|
"""
|
||||||
|
reporter = FakeMigrationProgressReporter()
|
||||||
|
reporter_instances.append(reporter)
|
||||||
|
return reporter
|
||||||
|
|
||||||
|
with engine.begin() as connection:
|
||||||
|
_create_legacy_v1_schema_with_sample_data(connection)
|
||||||
|
|
||||||
|
manager = DatabaseMigrationManager(
|
||||||
|
engine=engine,
|
||||||
|
registry=build_default_migration_registry(),
|
||||||
|
resolver=build_default_schema_version_resolver(),
|
||||||
|
progress_reporter_factory=_build_reporter,
|
||||||
|
)
|
||||||
|
|
||||||
|
migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)
|
||||||
|
|
||||||
|
assert migration_plan.step_count() == 1
|
||||||
|
assert len(reporter_instances) == 1
|
||||||
|
reporter_events = reporter_instances[0].events
|
||||||
|
|
||||||
|
assert reporter_events[0] == ("open", None, None, None)
|
||||||
|
assert reporter_events[1] == ("start", 12, "总迁移进度", "表")
|
||||||
|
assert reporter_events[-1] == ("close", None, None, None)
|
||||||
|
assert reporter_events.count(("advance", 1, "chat_sessions", None)) == 1
|
||||||
|
assert reporter_events.count(("advance", 1, "thinking_questions", None)) == 1
|
||||||
|
assert len([event for event in reporter_events if event[0] == "advance"]) == 12
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialize_database_calls_bootstrapper_before_create_all(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""数据库初始化入口应先准备迁移,再建表、补迁移并收尾。"""
|
||||||
|
call_order: List[str] = []
|
||||||
|
|
||||||
|
def _fake_prepare_database() -> DatabaseMigrationState:
|
||||||
|
"""返回测试用迁移状态。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseMigrationState: 不包含迁移步骤的测试状态。
|
||||||
|
"""
|
||||||
|
call_order.append("prepare_database")
|
||||||
|
return DatabaseMigrationState(
|
||||||
|
resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE),
|
||||||
|
target_version=LATEST_SCHEMA_VERSION,
|
||||||
|
plan=MigrationPlan(
|
||||||
|
current_version=EMPTY_SCHEMA_VERSION,
|
||||||
|
target_version=LATEST_SCHEMA_VERSION,
|
||||||
|
steps=[],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fake_create_all(bind) -> None:
|
||||||
|
"""记录建表调用。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bind: 传入的数据库绑定对象。
|
||||||
|
"""
|
||||||
|
del bind
|
||||||
|
call_order.append("create_all")
|
||||||
|
|
||||||
|
def _fake_migrate_action_records() -> None:
|
||||||
|
"""记录轻量补迁移调用。"""
|
||||||
|
call_order.append("migrate_action_records")
|
||||||
|
|
||||||
|
def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
|
||||||
|
"""记录迁移收尾调用。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
migration_state: 当前数据库迁移状态。
|
||||||
|
"""
|
||||||
|
del migration_state
|
||||||
|
call_order.append("finalize_database")
|
||||||
|
|
||||||
|
monkeypatch.setattr(database_module, "_db_initialized", False)
|
||||||
|
monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
|
||||||
|
monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
|
||||||
|
monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
|
||||||
|
monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)
|
||||||
|
monkeypatch.setattr(database_module, "_migrate_action_records_to_tool_records", _fake_migrate_action_records)
|
||||||
|
|
||||||
|
database_module.initialize_database()
|
||||||
|
|
||||||
|
assert call_order == [
|
||||||
|
"prepare_database",
|
||||||
|
"create_all",
|
||||||
|
"migrate_action_records",
|
||||||
|
"finalize_database",
|
||||||
|
]
|
||||||
@@ -1,18 +1,23 @@
|
|||||||
from rich.traceback import install
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generator, TYPE_CHECKING
|
from typing import ContextManager, Generator, TYPE_CHECKING
|
||||||
|
|
||||||
|
from rich.traceback import install
|
||||||
from sqlalchemy import event, text
|
from sqlalchemy import event, text
|
||||||
from sqlalchemy.engine import Engine
|
from sqlalchemy.engine import Engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
from sqlmodel import SQLModel, Session, create_engine
|
from sqlmodel import SQLModel, Session, create_engine
|
||||||
|
|
||||||
|
from src.common.database.migrations import create_database_migration_bootstrapper
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sqlite3 import Connection as SQLite3Connection
|
from sqlite3 import Connection as SQLite3Connection
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
logger = get_logger("database")
|
||||||
|
|
||||||
|
|
||||||
# 定义数据库文件路径
|
# 定义数据库文件路径
|
||||||
ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve()
|
||||||
@@ -53,6 +58,7 @@ SessionLocal = sessionmaker(
|
|||||||
bind=engine,
|
bind=engine,
|
||||||
class_=Session,
|
class_=Session,
|
||||||
)
|
)
|
||||||
|
_migration_bootstrapper = create_database_migration_bootstrapper(engine)
|
||||||
|
|
||||||
_db_initialized = False
|
_db_initialized = False
|
||||||
|
|
||||||
@@ -93,14 +99,29 @@ def _migrate_action_records_to_tool_records() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def initialize_database() -> None:
|
def initialize_database() -> None:
|
||||||
|
"""初始化数据库连接、结构与启动期迁移。
|
||||||
|
|
||||||
|
当前初始化流程遵循以下顺序:
|
||||||
|
1. 确保数据库目录存在;
|
||||||
|
2. 加载 SQLModel 模型定义;
|
||||||
|
3. 执行已注册的启动期迁移;
|
||||||
|
4. 兜底执行 ``create_all`` 确保当前模型定义已建表;
|
||||||
|
5. 执行项目现有的轻量数据补迁移逻辑。
|
||||||
|
"""
|
||||||
global _db_initialized
|
global _db_initialized
|
||||||
if _db_initialized:
|
if _db_initialized:
|
||||||
return
|
return
|
||||||
_DB_DIR.mkdir(parents=True, exist_ok=True)
|
_DB_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
import src.common.database.database_model # noqa: F401
|
import src.common.database.database_model # noqa: F401
|
||||||
|
|
||||||
|
migration_state = _migration_bootstrapper.prepare_database()
|
||||||
|
logger.info(
|
||||||
|
"数据库迁移准备完成,"
|
||||||
|
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
||||||
|
)
|
||||||
SQLModel.metadata.create_all(engine)
|
SQLModel.metadata.create_all(engine)
|
||||||
_migrate_action_records_to_tool_records()
|
_migrate_action_records_to_tool_records()
|
||||||
|
_migration_bootstrapper.finalize_database(migration_state)
|
||||||
_db_initialized = True
|
_db_initialized = True
|
||||||
|
|
||||||
|
|
||||||
@@ -150,8 +171,12 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]:
|
|||||||
session.close()
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
def get_db_session_manual():
|
def get_db_session_manual() -> ContextManager[Session]:
|
||||||
"""获取数据库会话的上下文管理器 (手动提交模式)。"""
|
"""获取数据库会话的上下文管理器 (手动提交模式)。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ContextManager[Session]: 手动提交模式的数据库会话上下文管理器。
|
||||||
|
"""
|
||||||
return get_db_session(auto_commit=False)
|
return get_db_session(auto_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
79
src/common/database/migrations/__init__.py
Normal file
79
src/common/database/migrations/__init__.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""数据库迁移基础设施导出模块。"""
|
||||||
|
|
||||||
|
from .bootstrap import DatabaseMigrationBootstrapper, create_database_migration_bootstrapper
|
||||||
|
from .builtin import (
|
||||||
|
EMPTY_SCHEMA_VERSION,
|
||||||
|
LATEST_SCHEMA_VERSION,
|
||||||
|
LEGACY_V1_SCHEMA_VERSION,
|
||||||
|
build_default_migration_registry,
|
||||||
|
build_default_schema_version_resolver,
|
||||||
|
)
|
||||||
|
from .exceptions import (
|
||||||
|
DatabaseMigrationConfigurationError,
|
||||||
|
DatabaseMigrationError,
|
||||||
|
DatabaseMigrationExecutionError,
|
||||||
|
DatabaseMigrationPlanningError,
|
||||||
|
DatabaseMigrationVersionError,
|
||||||
|
MissingMigrationStepError,
|
||||||
|
UnrecognizedDatabaseSchemaError,
|
||||||
|
UnsupportedMigrationDirectionError,
|
||||||
|
)
|
||||||
|
from .manager import DatabaseMigrationManager
|
||||||
|
from .models import (
|
||||||
|
ColumnSchema,
|
||||||
|
DatabaseMigrationState,
|
||||||
|
DatabaseSchemaSnapshot,
|
||||||
|
MigrationExecutionContext,
|
||||||
|
MigrationPlan,
|
||||||
|
MigrationStep,
|
||||||
|
ResolvedSchemaVersion,
|
||||||
|
SchemaVersionSource,
|
||||||
|
TableSchema,
|
||||||
|
)
|
||||||
|
from .planner import MigrationPlanner
|
||||||
|
from .progress import (
|
||||||
|
BaseMigrationProgressReporter,
|
||||||
|
RichMigrationProgressReporter,
|
||||||
|
create_rich_migration_progress_reporter,
|
||||||
|
)
|
||||||
|
from .registry import MigrationRegistry
|
||||||
|
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||||
|
from .schema import SQLiteSchemaInspector
|
||||||
|
from .version_store import SQLiteUserVersionStore
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseSchemaVersionDetector",
|
||||||
|
"BaseMigrationProgressReporter",
|
||||||
|
"build_default_migration_registry",
|
||||||
|
"build_default_schema_version_resolver",
|
||||||
|
"ColumnSchema",
|
||||||
|
"create_database_migration_bootstrapper",
|
||||||
|
"create_rich_migration_progress_reporter",
|
||||||
|
"DatabaseMigrationConfigurationError",
|
||||||
|
"DatabaseMigrationError",
|
||||||
|
"DatabaseMigrationBootstrapper",
|
||||||
|
"DatabaseMigrationExecutionError",
|
||||||
|
"DatabaseMigrationManager",
|
||||||
|
"DatabaseMigrationPlanningError",
|
||||||
|
"DatabaseMigrationState",
|
||||||
|
"DatabaseMigrationVersionError",
|
||||||
|
"DatabaseSchemaSnapshot",
|
||||||
|
"EMPTY_SCHEMA_VERSION",
|
||||||
|
"LATEST_SCHEMA_VERSION",
|
||||||
|
"LEGACY_V1_SCHEMA_VERSION",
|
||||||
|
"MigrationExecutionContext",
|
||||||
|
"MigrationPlan",
|
||||||
|
"MigrationPlanner",
|
||||||
|
"MigrationRegistry",
|
||||||
|
"MigrationStep",
|
||||||
|
"MissingMigrationStepError",
|
||||||
|
"ResolvedSchemaVersion",
|
||||||
|
"RichMigrationProgressReporter",
|
||||||
|
"SchemaVersionResolver",
|
||||||
|
"SchemaVersionSource",
|
||||||
|
"SQLiteSchemaInspector",
|
||||||
|
"SQLiteUserVersionStore",
|
||||||
|
"TableSchema",
|
||||||
|
"UnrecognizedDatabaseSchemaError",
|
||||||
|
"UnsupportedMigrationDirectionError",
|
||||||
|
]
|
||||||
171
src/common/database/migrations/bootstrap.py
Normal file
171
src/common/database/migrations/bootstrap.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""数据库迁移启动桥接层。"""
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
from .builtin import (
|
||||||
|
LATEST_SCHEMA_VERSION,
|
||||||
|
build_default_migration_registry,
|
||||||
|
build_default_schema_version_resolver,
|
||||||
|
)
|
||||||
|
from .exceptions import DatabaseMigrationExecutionError
|
||||||
|
from .manager import DatabaseMigrationManager
|
||||||
|
from .models import DatabaseMigrationState, MigrationPlan, ResolvedSchemaVersion, SchemaVersionSource
|
||||||
|
from .registry import MigrationRegistry
|
||||||
|
from .resolver import SchemaVersionResolver
|
||||||
|
from .version_store import SQLiteUserVersionStore
|
||||||
|
|
||||||
|
logger = get_logger("database_migration")
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMigrationBootstrapper:
|
||||||
|
"""数据库迁移启动桥接器。
|
||||||
|
|
||||||
|
该桥接器负责把数据库迁移基础设施接入现有启动流程,同时保持如下约束:
|
||||||
|
1. 若数据库为空,则直接交给当前模型定义建出最新结构;
|
||||||
|
2. 若数据库版本高于当前代码支持的最新版本,则立即终止启动;
|
||||||
|
3. 若存在待执行迁移步骤,则在正常建表流程之前先执行迁移;
|
||||||
|
4. 若数据库已是最新结构但尚未写入 ``user_version``,则在建表后补写版本号。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
manager: DatabaseMigrationManager,
|
||||||
|
latest_schema_version: int = LATEST_SCHEMA_VERSION,
|
||||||
|
) -> None:
|
||||||
|
"""初始化数据库迁移启动桥接器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
manager: 数据库迁移编排器。
|
||||||
|
latest_schema_version: 当前代码支持的最新 schema 版本号。
|
||||||
|
"""
|
||||||
|
self.manager = manager
|
||||||
|
self.latest_schema_version = latest_schema_version
|
||||||
|
|
||||||
|
def prepare_database(self) -> DatabaseMigrationState:
|
||||||
|
"""为数据库初始化阶段准备迁移状态。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseMigrationState: 迁移准备完成后的数据库状态。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationExecutionError: 当数据库版本高于当前代码支持版本时抛出。
|
||||||
|
"""
|
||||||
|
with self.manager.engine.connect() as connection:
|
||||||
|
resolved_version = self.manager.resolver.resolve(connection)
|
||||||
|
|
||||||
|
if resolved_version.version > self.latest_schema_version:
|
||||||
|
raise DatabaseMigrationExecutionError(
|
||||||
|
"当前数据库版本高于代码内注册的最新迁移版本,已拒绝继续启动。"
|
||||||
|
f" 数据库版本={resolved_version.version},代码支持版本={self.latest_schema_version}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if resolved_version.source == SchemaVersionSource.EMPTY_DATABASE:
|
||||||
|
logger.info(
|
||||||
|
"检测到空数据库,将直接根据当前模型创建最新结构。"
|
||||||
|
f" 目标版本={self.latest_schema_version}"
|
||||||
|
)
|
||||||
|
return self._build_noop_state(
|
||||||
|
current_version=resolved_version.version,
|
||||||
|
target_version=self.latest_schema_version,
|
||||||
|
resolved_state=resolved_version,
|
||||||
|
)
|
||||||
|
|
||||||
|
migration_state = self.manager.describe_state(target_version=self.latest_schema_version)
|
||||||
|
if not migration_state.requires_migration():
|
||||||
|
logger.info(
|
||||||
|
f"数据库 schema 已是目标版本,无需迁移。当前版本={migration_state.resolved_version.version}"
|
||||||
|
)
|
||||||
|
return migration_state
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"检测到数据库需要迁移,"
|
||||||
|
f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}"
|
||||||
|
)
|
||||||
|
self.manager.migrate(target_version=self.latest_schema_version)
|
||||||
|
return self.manager.describe_state(target_version=self.latest_schema_version)
|
||||||
|
|
||||||
|
def finalize_database(self, migration_state: DatabaseMigrationState) -> None:
|
||||||
|
"""在数据库初始化末尾补写最终 schema 版本号。
|
||||||
|
|
||||||
|
该方法主要负责两类场景:
|
||||||
|
1. 空库首次建表完成后,将 ``user_version`` 写入为最新版本;
|
||||||
|
2. 已是最新结构但此前未写入 ``user_version`` 的数据库,补写版本号。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
migration_state: 初始化前解析得到的迁移状态。
|
||||||
|
"""
|
||||||
|
if migration_state.requires_migration():
|
||||||
|
return
|
||||||
|
if migration_state.target_version <= 0:
|
||||||
|
return
|
||||||
|
if migration_state.resolved_version.source == SchemaVersionSource.PRAGMA:
|
||||||
|
return
|
||||||
|
|
||||||
|
with self.manager.engine.begin() as connection:
|
||||||
|
self.manager.version_store.write_version(connection, migration_state.target_version)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"数据库 schema 版本写入完成。"
|
||||||
|
f" 来源={migration_state.resolved_version.source.value},"
|
||||||
|
f" 写入版本={migration_state.target_version}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_noop_state(
|
||||||
|
self,
|
||||||
|
current_version: int,
|
||||||
|
target_version: int,
|
||||||
|
resolved_state: ResolvedSchemaVersion,
|
||||||
|
) -> DatabaseMigrationState:
|
||||||
|
"""构建无迁移动作的数据库状态对象。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_version: 当前数据库版本号。
|
||||||
|
target_version: 当前初始化流程期望达到的目标版本号。
|
||||||
|
resolved_state: 已解析的数据库版本状态。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseMigrationState: 不包含迁移步骤的状态对象。
|
||||||
|
"""
|
||||||
|
return DatabaseMigrationState(
|
||||||
|
resolved_version=resolved_state,
|
||||||
|
target_version=target_version,
|
||||||
|
plan=MigrationPlan(current_version=current_version, target_version=target_version, steps=[]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_database_migration_bootstrapper(
|
||||||
|
engine: Engine,
|
||||||
|
registry: Optional[MigrationRegistry] = None,
|
||||||
|
resolver: Optional[SchemaVersionResolver] = None,
|
||||||
|
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||||
|
latest_schema_version: int = LATEST_SCHEMA_VERSION,
|
||||||
|
) -> DatabaseMigrationBootstrapper:
|
||||||
|
"""创建数据库迁移启动桥接器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine: 目标数据库引擎。
|
||||||
|
registry: 迁移步骤注册表;未提供时使用默认注册表。
|
||||||
|
resolver: 数据库版本解析器;未提供时使用默认解析器。
|
||||||
|
version_store: 版本存储器;未提供时使用默认存储器。
|
||||||
|
latest_schema_version: 当前代码支持的最新 schema 版本号。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseMigrationBootstrapper: 配置完成的数据库迁移启动桥接器。
|
||||||
|
"""
|
||||||
|
migration_registry = registry or build_default_migration_registry()
|
||||||
|
migration_resolver = resolver or build_default_schema_version_resolver()
|
||||||
|
migration_version_store = version_store or SQLiteUserVersionStore()
|
||||||
|
migration_manager = DatabaseMigrationManager(
|
||||||
|
engine=engine,
|
||||||
|
registry=migration_registry,
|
||||||
|
resolver=migration_resolver,
|
||||||
|
version_store=migration_version_store,
|
||||||
|
)
|
||||||
|
return DatabaseMigrationBootstrapper(
|
||||||
|
manager=migration_manager,
|
||||||
|
latest_schema_version=latest_schema_version,
|
||||||
|
)
|
||||||
159
src/common/database/migrations/builtin.py
Normal file
159
src/common/database/migrations/builtin.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
"""数据库迁移内置版本与默认注册表。"""
|
||||||
|
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from .legacy_v1_to_v2 import migrate_legacy_v1_to_v2
|
||||||
|
from .models import DatabaseSchemaSnapshot, MigrationStep
|
||||||
|
from .registry import MigrationRegistry
|
||||||
|
from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver
|
||||||
|
from .version_store import SQLiteUserVersionStore
|
||||||
|
from .schema import SQLiteSchemaInspector
|
||||||
|
|
||||||
|
EMPTY_SCHEMA_VERSION = 0
|
||||||
|
LEGACY_V1_SCHEMA_VERSION = 1
|
||||||
|
LATEST_SCHEMA_VERSION = 2
|
||||||
|
|
||||||
|
_LEGACY_V1_EXCLUSIVE_TABLES = (
|
||||||
|
"chat_streams",
|
||||||
|
"emoji",
|
||||||
|
"emoji_description_cache",
|
||||||
|
"expression",
|
||||||
|
"group_info",
|
||||||
|
"image_descriptions",
|
||||||
|
"jargon",
|
||||||
|
"messages",
|
||||||
|
"thinking_back",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
|
||||||
|
"""当前最新 schema 结构探测器。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""返回探测器名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 当前探测器名称。
|
||||||
|
"""
|
||||||
|
return "latest_schema_detector"
|
||||||
|
|
||||||
|
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||||
|
"""检测数据库是否已经是当前最新结构。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot: 当前数据库结构快照。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[int]: 若识别为最新结构则返回最新版本号,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||||
|
return None
|
||||||
|
|
||||||
|
latest_marker_tables = (
|
||||||
|
"mai_messages",
|
||||||
|
"chat_sessions",
|
||||||
|
"expressions",
|
||||||
|
"jargons",
|
||||||
|
"thinking_questions",
|
||||||
|
"tool_records",
|
||||||
|
)
|
||||||
|
if not all(snapshot.has_table(table_name) for table_name in latest_marker_tables):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_column("images", "image_hash"):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_column("images", "full_path"):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_column("images", "image_type"):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_column("action_records", "session_id"):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_column("chat_history", "session_id"):
|
||||||
|
return None
|
||||||
|
if not snapshot.has_column("person_info", "user_nickname"):
|
||||||
|
return None
|
||||||
|
return LATEST_SCHEMA_VERSION
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
|
||||||
|
"""旧版 ``0.x`` schema 结构探测器。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""返回探测器名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 当前探测器名称。
|
||||||
|
"""
|
||||||
|
return "legacy_v1_schema_detector"
|
||||||
|
|
||||||
|
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||||
|
"""检测数据库是否为旧版 ``0.x`` 结构。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot: 当前数据库结构快照。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[int]: 若识别为旧版结构则返回 ``1``,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
if any(snapshot.has_table(table_name) for table_name in _LEGACY_V1_EXCLUSIVE_TABLES):
|
||||||
|
return LEGACY_V1_SCHEMA_VERSION
|
||||||
|
|
||||||
|
legacy_shared_markers = (
|
||||||
|
("action_records", ("chat_id", "time")),
|
||||||
|
("chat_history", ("chat_id", "original_text")),
|
||||||
|
("images", ("emoji_hash", "path", "type")),
|
||||||
|
("llm_usage", ("model_api_provider", "status")),
|
||||||
|
("online_time", ("duration",)),
|
||||||
|
("person_info", ("nickname", "group_nick_name")),
|
||||||
|
)
|
||||||
|
for table_name, required_columns in legacy_shared_markers:
|
||||||
|
if snapshot.has_table(table_name) and all(
|
||||||
|
snapshot.has_column(table_name, column_name) for column_name in required_columns
|
||||||
|
):
|
||||||
|
return LEGACY_V1_SCHEMA_VERSION
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
|
||||||
|
"""构建默认 schema 版本探测器链。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[BaseSchemaVersionDetector]: 按优先级排序的探测器列表。
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
LatestSchemaVersionDetector(),
|
||||||
|
LegacyV1SchemaDetector(),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_default_schema_version_resolver() -> SchemaVersionResolver:
|
||||||
|
"""构建默认 schema 版本解析器。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SchemaVersionResolver: 配置完成的 schema 版本解析器。
|
||||||
|
"""
|
||||||
|
return SchemaVersionResolver(
|
||||||
|
version_store=SQLiteUserVersionStore(),
|
||||||
|
schema_inspector=SQLiteSchemaInspector(),
|
||||||
|
detectors=build_default_schema_version_detectors(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_default_migration_registry() -> MigrationRegistry:
|
||||||
|
"""构建默认迁移步骤注册表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MigrationRegistry: 含默认迁移步骤的注册表实例。
|
||||||
|
"""
|
||||||
|
return MigrationRegistry(
|
||||||
|
steps=[
|
||||||
|
MigrationStep(
|
||||||
|
version_from=LEGACY_V1_SCHEMA_VERSION,
|
||||||
|
version_to=LATEST_SCHEMA_VERSION,
|
||||||
|
name="legacy_v1_to_latest_v2",
|
||||||
|
description="将旧版 0.x 数据库整体迁移到当前最新 schema。",
|
||||||
|
handler=migrate_legacy_v1_to_v2,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
33
src/common/database/migrations/exceptions.py
Normal file
33
src/common/database/migrations/exceptions.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
"""数据库迁移基础设施异常定义。"""
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMigrationError(Exception):
|
||||||
|
"""数据库迁移基础异常。"""
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMigrationConfigurationError(DatabaseMigrationError):
|
||||||
|
"""数据库迁移配置不合法。"""
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMigrationPlanningError(DatabaseMigrationError):
|
||||||
|
"""数据库迁移计划生成失败。"""
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMigrationExecutionError(DatabaseMigrationError):
|
||||||
|
"""数据库迁移执行失败。"""
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMigrationVersionError(DatabaseMigrationError):
|
||||||
|
"""数据库版本读写或校验失败。"""
|
||||||
|
|
||||||
|
|
||||||
|
class MissingMigrationStepError(DatabaseMigrationPlanningError):
|
||||||
|
"""缺少某个版本区间所需的迁移步骤。"""
|
||||||
|
|
||||||
|
|
||||||
|
class UnsupportedMigrationDirectionError(DatabaseMigrationPlanningError):
|
||||||
|
"""当前迁移方向不被支持。"""
|
||||||
|
|
||||||
|
|
||||||
|
class UnrecognizedDatabaseSchemaError(DatabaseMigrationVersionError):
|
||||||
|
"""无法识别未标记版本数据库的结构。"""
|
||||||
1384
src/common/database/migrations/legacy_v1_to_v2.py
Normal file
1384
src/common/database/migrations/legacy_v1_to_v2.py
Normal file
File diff suppressed because it is too large
Load Diff
205
src/common/database/migrations/manager.py
Normal file
205
src/common/database/migrations/manager.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""数据库迁移编排器。"""
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Connection, Engine
|
||||||
|
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
from .exceptions import DatabaseMigrationExecutionError
|
||||||
|
from .models import DatabaseMigrationState, MigrationExecutionContext, MigrationPlan
|
||||||
|
from .planner import MigrationPlanner
|
||||||
|
from .progress import BaseMigrationProgressReporter, create_rich_migration_progress_reporter
|
||||||
|
from .registry import MigrationRegistry
|
||||||
|
from .resolver import SchemaVersionResolver
|
||||||
|
from .version_store import SQLiteUserVersionStore
|
||||||
|
|
||||||
|
logger = get_logger("database_migration")
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMigrationManager:
|
||||||
|
"""数据库迁移编排器。
|
||||||
|
|
||||||
|
该类只负责基础设施层面的编排工作,包括:
|
||||||
|
1. 解析当前数据库版本;
|
||||||
|
2. 生成迁移计划;
|
||||||
|
3. 顺序执行已注册迁移步骤;
|
||||||
|
4. 在每一步成功后更新 ``user_version``。
|
||||||
|
|
||||||
|
当前模块不内置任何业务迁移步骤,也不会自动接入项目启动流程。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
engine: Engine,
|
||||||
|
registry: Optional[MigrationRegistry] = None,
|
||||||
|
planner: Optional[MigrationPlanner] = None,
|
||||||
|
resolver: Optional[SchemaVersionResolver] = None,
|
||||||
|
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||||
|
progress_reporter_factory: Optional[Callable[[], BaseMigrationProgressReporter]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""初始化数据库迁移编排器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
engine: 目标数据库引擎。
|
||||||
|
registry: 迁移步骤注册表。
|
||||||
|
planner: 迁移计划生成器。
|
||||||
|
resolver: 数据库版本解析器。
|
||||||
|
version_store: 版本存储器。
|
||||||
|
progress_reporter_factory: 迁移进度上报器工厂。
|
||||||
|
"""
|
||||||
|
self.engine = engine
|
||||||
|
self.registry = registry or MigrationRegistry()
|
||||||
|
self.planner = planner or MigrationPlanner()
|
||||||
|
self.resolver = resolver or SchemaVersionResolver()
|
||||||
|
self.version_store = version_store or SQLiteUserVersionStore()
|
||||||
|
self.progress_reporter_factory = progress_reporter_factory or create_rich_migration_progress_reporter
|
||||||
|
|
||||||
|
def describe_state(self, target_version: Optional[int] = None) -> DatabaseMigrationState:
|
||||||
|
"""描述当前数据库的迁移状态。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseMigrationState: 当前数据库迁移状态。
|
||||||
|
"""
|
||||||
|
with self.engine.connect() as connection:
|
||||||
|
resolved_version = self.resolver.resolve(connection)
|
||||||
|
|
||||||
|
effective_target_version = self._resolve_target_version(target_version)
|
||||||
|
migration_plan = self.planner.plan(
|
||||||
|
current_version=resolved_version.version,
|
||||||
|
target_version=effective_target_version,
|
||||||
|
registry=self.registry,
|
||||||
|
)
|
||||||
|
return DatabaseMigrationState(
|
||||||
|
resolved_version=resolved_version,
|
||||||
|
target_version=effective_target_version,
|
||||||
|
plan=migration_plan,
|
||||||
|
)
|
||||||
|
|
||||||
|
def plan(self, target_version: Optional[int] = None) -> MigrationPlan:
|
||||||
|
"""生成当前数据库的迁移计划。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MigrationPlan: 当前数据库对应的迁移计划。
|
||||||
|
"""
|
||||||
|
return self.describe_state(target_version=target_version).plan
|
||||||
|
|
||||||
|
def migrate(self, target_version: Optional[int] = None) -> MigrationPlan:
|
||||||
|
"""执行迁移计划。
|
||||||
|
|
||||||
|
注意:
|
||||||
|
若当前数据库是通过结构探测得出的版本,且计划为空,本方法不会自动把该
|
||||||
|
版本写回 ``user_version``。这样做是为了避免在尚未明确接入策略前引入隐式
|
||||||
|
副作用。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MigrationPlan: 已执行的迁移计划。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationExecutionError: 当迁移步骤执行失败时抛出。
|
||||||
|
"""
|
||||||
|
migration_state = self.describe_state(target_version=target_version)
|
||||||
|
migration_plan = migration_state.plan
|
||||||
|
if migration_plan.is_empty():
|
||||||
|
logger.info("数据库迁移计划为空,跳过执行。")
|
||||||
|
return migration_plan
|
||||||
|
|
||||||
|
current_version = migration_state.resolved_version.version
|
||||||
|
total_steps = migration_plan.step_count()
|
||||||
|
for step_index, step in enumerate(migration_plan.steps, start=1):
|
||||||
|
logger.info(
|
||||||
|
f"开始执行数据库迁移步骤: {step.name} ({step.version_from} -> {step.version_to})"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with self.progress_reporter_factory() as progress_reporter:
|
||||||
|
if step.transactional:
|
||||||
|
with self.engine.begin() as connection:
|
||||||
|
execution_context = self._build_execution_context(
|
||||||
|
connection=connection,
|
||||||
|
current_version=current_version,
|
||||||
|
migration_plan=migration_plan,
|
||||||
|
step_index=step_index,
|
||||||
|
step_name=step.name,
|
||||||
|
total_steps=total_steps,
|
||||||
|
progress_reporter=progress_reporter,
|
||||||
|
)
|
||||||
|
step.run(execution_context)
|
||||||
|
self.version_store.write_version(connection, step.version_to)
|
||||||
|
else:
|
||||||
|
with self.engine.connect() as connection:
|
||||||
|
execution_context = self._build_execution_context(
|
||||||
|
connection=connection,
|
||||||
|
current_version=current_version,
|
||||||
|
migration_plan=migration_plan,
|
||||||
|
step_index=step_index,
|
||||||
|
step_name=step.name,
|
||||||
|
total_steps=total_steps,
|
||||||
|
progress_reporter=progress_reporter,
|
||||||
|
)
|
||||||
|
step.run(execution_context)
|
||||||
|
self.version_store.write_version(connection, step.version_to)
|
||||||
|
connection.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
raise DatabaseMigrationExecutionError(
|
||||||
|
f"执行迁移步骤 {step.name} ({step.version_from} -> {step.version_to}) 失败。"
|
||||||
|
) from exc
|
||||||
|
current_version = step.version_to
|
||||||
|
logger.info(f"数据库迁移步骤执行完成: {step.name},当前版本已更新为 {current_version}")
|
||||||
|
|
||||||
|
return migration_plan
|
||||||
|
|
||||||
|
def _resolve_target_version(self, target_version: Optional[int]) -> int:
|
||||||
|
"""解析最终目标版本号。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_version: 调用方显式指定的目标版本。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 最终用于规划和执行的目标版本号。
|
||||||
|
"""
|
||||||
|
if target_version is not None:
|
||||||
|
return target_version
|
||||||
|
return self.registry.latest_version()
|
||||||
|
|
||||||
|
def _build_execution_context(
|
||||||
|
self,
|
||||||
|
connection: Connection,
|
||||||
|
current_version: int,
|
||||||
|
migration_plan: MigrationPlan,
|
||||||
|
step_index: int,
|
||||||
|
step_name: str,
|
||||||
|
total_steps: int,
|
||||||
|
progress_reporter: BaseMigrationProgressReporter,
|
||||||
|
) -> MigrationExecutionContext:
|
||||||
|
"""构建单个迁移步骤的执行上下文。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前迁移步骤使用的数据库连接。
|
||||||
|
current_version: 当前数据库版本。
|
||||||
|
migration_plan: 当前迁移计划。
|
||||||
|
step_index: 当前步骤序号,从 ``1`` 开始。
|
||||||
|
step_name: 当前步骤名称。
|
||||||
|
total_steps: 计划总步骤数。
|
||||||
|
progress_reporter: 当前步骤使用的进度上报器。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MigrationExecutionContext: 当前步骤的执行上下文对象。
|
||||||
|
"""
|
||||||
|
return MigrationExecutionContext(
|
||||||
|
connection=connection,
|
||||||
|
current_version=current_version,
|
||||||
|
target_version=migration_plan.target_version,
|
||||||
|
step_index=step_index,
|
||||||
|
step_name=step_name,
|
||||||
|
total_steps=total_steps,
|
||||||
|
progress_reporter=progress_reporter,
|
||||||
|
)
|
||||||
285
src/common/database/migrations/models.py
Normal file
285
src/common/database/migrations/models.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""数据库迁移基础设施核心数据模型。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Callable, Dict, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .progress import BaseMigrationProgressReporter
|
||||||
|
|
||||||
|
|
||||||
|
def _utc_now() -> datetime:
|
||||||
|
"""返回当前 UTC 时间。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
datetime: 当前 UTC 时间。
|
||||||
|
"""
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaVersionSource(str, Enum):
|
||||||
|
"""数据库版本来源。"""
|
||||||
|
|
||||||
|
PRAGMA = "pragma"
|
||||||
|
DETECTOR = "detector"
|
||||||
|
EMPTY_DATABASE = "empty_database"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ColumnSchema:
|
||||||
|
"""数据库列结构快照。"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
declared_type: str
|
||||||
|
default_value: Optional[str]
|
||||||
|
is_not_null: bool
|
||||||
|
primary_key_position: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TableSchema:
|
||||||
|
"""数据库表结构快照。"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
columns: Dict[str, ColumnSchema]
|
||||||
|
|
||||||
|
def has_column(self, column_name: str) -> bool:
|
||||||
|
"""判断表中是否存在指定列。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column_name: 待检查的列名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若列存在则返回 ``True``,否则返回 ``False``。
|
||||||
|
"""
|
||||||
|
return column_name in self.columns
|
||||||
|
|
||||||
|
def get_column(self, column_name: str) -> Optional[ColumnSchema]:
|
||||||
|
"""获取指定列的结构信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column_name: 待获取的列名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[ColumnSchema]: 列存在时返回列结构,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
return self.columns.get(column_name)
|
||||||
|
|
||||||
|
def column_names(self) -> List[str]:
|
||||||
|
"""返回当前表中全部列名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 按字母顺序排列的列名列表。
|
||||||
|
"""
|
||||||
|
return sorted(self.columns)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DatabaseSchemaSnapshot:
|
||||||
|
"""数据库结构快照。"""
|
||||||
|
|
||||||
|
tables: Dict[str, TableSchema]
|
||||||
|
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
"""判断数据库是否没有任何用户表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若数据库中没有用户表则返回 ``True``。
|
||||||
|
"""
|
||||||
|
return not self.tables
|
||||||
|
|
||||||
|
def has_table(self, table_name: str) -> bool:
|
||||||
|
"""判断数据库是否存在指定表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 待检查的表名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若表存在则返回 ``True``,否则返回 ``False``。
|
||||||
|
"""
|
||||||
|
return table_name in self.tables
|
||||||
|
|
||||||
|
def has_column(self, table_name: str, column_name: str) -> bool:
|
||||||
|
"""判断数据库指定表中是否存在指定列。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 待检查的表名。
|
||||||
|
column_name: 待检查的列名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若表和列均存在则返回 ``True``。
|
||||||
|
"""
|
||||||
|
table_schema = self.get_table(table_name)
|
||||||
|
if table_schema is None:
|
||||||
|
return False
|
||||||
|
return table_schema.has_column(column_name)
|
||||||
|
|
||||||
|
def get_table(self, table_name: str) -> Optional[TableSchema]:
|
||||||
|
"""获取指定表的结构信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: 待获取的表名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[TableSchema]: 表存在时返回对应结构,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
return self.tables.get(table_name)
|
||||||
|
|
||||||
|
def table_names(self) -> List[str]:
|
||||||
|
"""返回当前数据库中的全部用户表名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 按字母顺序排列的表名列表。
|
||||||
|
"""
|
||||||
|
return sorted(self.tables)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ResolvedSchemaVersion:
|
||||||
|
"""解析后的数据库版本信息。"""
|
||||||
|
|
||||||
|
version: int
|
||||||
|
source: SchemaVersionSource
|
||||||
|
detector_name: Optional[str] = None
|
||||||
|
snapshot: Optional[DatabaseSchemaSnapshot] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MigrationExecutionContext:
|
||||||
|
"""单个迁移步骤的执行上下文。"""
|
||||||
|
|
||||||
|
connection: Connection
|
||||||
|
current_version: int
|
||||||
|
target_version: int
|
||||||
|
step_index: int
|
||||||
|
step_name: str
|
||||||
|
total_steps: int
|
||||||
|
started_at: datetime = field(default_factory=_utc_now)
|
||||||
|
progress_reporter: Optional["BaseMigrationProgressReporter"] = None
|
||||||
|
|
||||||
|
def is_last_step(self) -> bool:
|
||||||
|
"""判断当前步骤是否为最后一步。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若当前步骤已是计划中的最后一步则返回 ``True``。
|
||||||
|
"""
|
||||||
|
return self.step_index >= self.total_steps
|
||||||
|
|
||||||
|
def start_progress(
|
||||||
|
self,
|
||||||
|
total: int,
|
||||||
|
description: str = "总迁移进度",
|
||||||
|
unit_name: str = "表",
|
||||||
|
) -> None:
|
||||||
|
"""启动当前迁移步骤的进度展示。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total: 当前步骤需要处理的总项目数。
|
||||||
|
description: 进度描述文本。
|
||||||
|
unit_name: 进度单位名称。
|
||||||
|
"""
|
||||||
|
if self.progress_reporter is None:
|
||||||
|
return
|
||||||
|
self.progress_reporter.start(total=total, description=description, unit_name=unit_name)
|
||||||
|
|
||||||
|
def advance_progress(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||||
|
"""推进当前迁移步骤的进度展示。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
advance: 本次推进的步数。
|
||||||
|
item_name: 当前完成的项目名称。
|
||||||
|
"""
|
||||||
|
if self.progress_reporter is None:
|
||||||
|
return
|
||||||
|
self.progress_reporter.advance(advance=advance, item_name=item_name)
|
||||||
|
|
||||||
|
|
||||||
|
MigrationHandler = Callable[[MigrationExecutionContext], None]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MigrationStep:
|
||||||
|
"""单个数据库迁移步骤定义。"""
|
||||||
|
|
||||||
|
version_from: int
|
||||||
|
version_to: int
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
handler: MigrationHandler
|
||||||
|
transactional: bool = True
|
||||||
|
|
||||||
|
def __post_init__(self) -> None:
|
||||||
|
"""校验迁移步骤定义是否合法。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 当版本号不合法或迁移方向错误时抛出。
|
||||||
|
"""
|
||||||
|
if self.version_from < 0:
|
||||||
|
raise ValueError("迁移起始版本不能小于 0。")
|
||||||
|
if self.version_to <= self.version_from:
|
||||||
|
raise ValueError("迁移目标版本必须大于起始版本。")
|
||||||
|
|
||||||
|
def run(self, context: MigrationExecutionContext) -> None:
|
||||||
|
"""执行当前迁移步骤。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context: 当前迁移步骤的执行上下文。
|
||||||
|
"""
|
||||||
|
self.handler(context)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MigrationPlan:
|
||||||
|
"""数据库迁移执行计划。"""
|
||||||
|
|
||||||
|
current_version: int
|
||||||
|
target_version: int
|
||||||
|
steps: List[MigrationStep]
|
||||||
|
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
"""判断迁移计划是否为空。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若无需执行任何迁移步骤则返回 ``True``。
|
||||||
|
"""
|
||||||
|
return not self.steps
|
||||||
|
|
||||||
|
def step_count(self) -> int:
|
||||||
|
"""返回迁移计划中的步骤数量。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 当前计划中的迁移步骤数。
|
||||||
|
"""
|
||||||
|
return len(self.steps)
|
||||||
|
|
||||||
|
def latest_reachable_version(self) -> int:
|
||||||
|
"""返回该计划执行后的最终版本。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 若计划为空则返回当前版本,否则返回最后一步的目标版本。
|
||||||
|
"""
|
||||||
|
if self.is_empty():
|
||||||
|
return self.current_version
|
||||||
|
return self.steps[-1].version_to
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DatabaseMigrationState:
|
||||||
|
"""数据库迁移状态描述。"""
|
||||||
|
|
||||||
|
resolved_version: ResolvedSchemaVersion
|
||||||
|
target_version: int
|
||||||
|
plan: MigrationPlan
|
||||||
|
|
||||||
|
def requires_migration(self) -> bool:
|
||||||
|
"""判断当前状态是否需要执行迁移。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若计划中存在待执行迁移步骤则返回 ``True``。
|
||||||
|
"""
|
||||||
|
return not self.plan.is_empty()
|
||||||
108
src/common/database/migrations/planner.py
Normal file
108
src/common/database/migrations/planner.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""数据库迁移计划生成器。"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
from .exceptions import (
|
||||||
|
DatabaseMigrationPlanningError,
|
||||||
|
MissingMigrationStepError,
|
||||||
|
UnsupportedMigrationDirectionError,
|
||||||
|
)
|
||||||
|
from .models import MigrationPlan, MigrationStep
|
||||||
|
from .registry import MigrationRegistry
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationPlanner:
|
||||||
|
"""数据库迁移计划生成器。"""
|
||||||
|
|
||||||
|
def plan(
|
||||||
|
self,
|
||||||
|
current_version: int,
|
||||||
|
target_version: int,
|
||||||
|
registry: MigrationRegistry,
|
||||||
|
) -> MigrationPlan:
|
||||||
|
"""根据当前版本与目标版本生成迁移计划。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_version: 当前数据库版本。
|
||||||
|
target_version: 目标数据库版本。
|
||||||
|
registry: 迁移步骤注册表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MigrationPlan: 按顺序执行的迁移计划。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationPlanningError: 当版本号非法时抛出。
|
||||||
|
MissingMigrationStepError: 当所需迁移步骤缺失时抛出。
|
||||||
|
UnsupportedMigrationDirectionError: 当请求降级迁移时抛出。
|
||||||
|
"""
|
||||||
|
self._validate_version(current_version, "current_version")
|
||||||
|
self._validate_version(target_version, "target_version")
|
||||||
|
|
||||||
|
if target_version < current_version:
|
||||||
|
raise UnsupportedMigrationDirectionError(
|
||||||
|
f"当前仅支持升级迁移,不支持从 {current_version} 降级到 {target_version}。"
|
||||||
|
)
|
||||||
|
if target_version == current_version:
|
||||||
|
return MigrationPlan(current_version=current_version, target_version=target_version, steps=[])
|
||||||
|
|
||||||
|
steps = self._build_steps(current_version, target_version, registry)
|
||||||
|
return MigrationPlan(current_version=current_version, target_version=target_version, steps=steps)
|
||||||
|
|
||||||
|
def plan_to_latest(self, current_version: int, registry: MigrationRegistry) -> MigrationPlan:
|
||||||
|
"""生成迁移到注册表最新版本的执行计划。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_version: 当前数据库版本。
|
||||||
|
registry: 迁移步骤注册表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MigrationPlan: 指向最新版本的迁移计划。
|
||||||
|
"""
|
||||||
|
target_version = registry.latest_version()
|
||||||
|
return self.plan(current_version=current_version, target_version=target_version, registry=registry)
|
||||||
|
|
||||||
|
def _build_steps(
|
||||||
|
self,
|
||||||
|
current_version: int,
|
||||||
|
target_version: int,
|
||||||
|
registry: MigrationRegistry,
|
||||||
|
) -> List[MigrationStep]:
|
||||||
|
"""按顺序拼装迁移步骤链。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_version: 当前数据库版本。
|
||||||
|
target_version: 目标数据库版本。
|
||||||
|
registry: 迁移步骤注册表。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[MigrationStep]: 按顺序执行的迁移步骤列表。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
MissingMigrationStepError: 当中间某一版本缺少迁移步骤时抛出。
|
||||||
|
"""
|
||||||
|
planned_steps: List[MigrationStep] = []
|
||||||
|
next_version = current_version
|
||||||
|
|
||||||
|
while next_version < target_version:
|
||||||
|
step = registry.get_step(next_version)
|
||||||
|
if step is None:
|
||||||
|
raise MissingMigrationStepError(
|
||||||
|
f"缺少从版本 {next_version} 升级到版本 {next_version + 1} 的迁移步骤。"
|
||||||
|
)
|
||||||
|
planned_steps.append(step)
|
||||||
|
next_version = step.version_to
|
||||||
|
|
||||||
|
return planned_steps
|
||||||
|
|
||||||
|
def _validate_version(self, version: int, field_name: str) -> None:
|
||||||
|
"""校验版本号是否合法。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version: 待校验的版本号。
|
||||||
|
field_name: 当前版本号对应的字段名。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationPlanningError: 当版本号非法时抛出。
|
||||||
|
"""
|
||||||
|
if version < 0:
|
||||||
|
raise DatabaseMigrationPlanningError(f"{field_name} 不能小于 0: {version}")
|
||||||
272
src/common/database/migrations/progress.py
Normal file
272
src/common/database/migrations/progress.py
Normal file
@@ -0,0 +1,272 @@
|
|||||||
|
"""数据库迁移进度展示工具。"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from rich.console import Console
|
||||||
|
from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID
|
||||||
|
from rich.text import Text
|
||||||
|
|
||||||
|
|
||||||
|
def _format_duration(total_seconds: Optional[float]) -> str:
|
||||||
|
"""将秒数格式化为适合展示的耗时文本。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total_seconds: 总秒数;为空时表示暂不可用。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 格式化后的耗时文本。
|
||||||
|
"""
|
||||||
|
if total_seconds is None:
|
||||||
|
return "--:--:--"
|
||||||
|
safe_seconds = max(total_seconds, 0.0)
|
||||||
|
return str(timedelta(seconds=int(safe_seconds)))
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationSummaryColumn(ProgressColumn):
|
||||||
|
"""渲染数据库迁移总进度摘要列。"""
|
||||||
|
|
||||||
|
def render(self, task: Task) -> Text:
|
||||||
|
"""渲染当前任务的总进度摘要。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: 当前进度任务对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text: 渲染后的摘要文本。
|
||||||
|
"""
|
||||||
|
display_total = task.fields.get("display_total", task.total)
|
||||||
|
total_text = "?" if display_total is None else str(int(display_total))
|
||||||
|
completed_text = str(int(task.completed))
|
||||||
|
return Text(f"总迁移进度({completed_text}/{total_text})")
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationSpeedColumn(ProgressColumn):
|
||||||
|
"""渲染数据库迁移速度列。"""
|
||||||
|
|
||||||
|
def render(self, task: Task) -> Text:
|
||||||
|
"""渲染当前任务的速度信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: 当前进度任务对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text: 渲染后的速度文本。
|
||||||
|
"""
|
||||||
|
unit_name = str(task.fields.get("unit_name", "项"))
|
||||||
|
if task.speed is None or task.speed <= 0:
|
||||||
|
return Text(f"-- {unit_name}/s")
|
||||||
|
return Text(f"{task.speed:.2f} {unit_name}/s")
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationElapsedColumn(ProgressColumn):
|
||||||
|
"""渲染数据库迁移已用时间列。"""
|
||||||
|
|
||||||
|
def render(self, task: Task) -> Text:
|
||||||
|
"""渲染当前任务的已用时间。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: 当前进度任务对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text: 渲染后的已用时间文本。
|
||||||
|
"""
|
||||||
|
return Text(f"已用时间 {_format_duration(task.elapsed)}")
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationRemainingColumn(ProgressColumn):
|
||||||
|
"""渲染数据库迁移预估剩余时间列。"""
|
||||||
|
|
||||||
|
def render(self, task: Task) -> Text:
|
||||||
|
"""渲染当前任务的预估剩余时间。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
task: 当前进度任务对象。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Text: 渲染后的预估剩余时间文本。
|
||||||
|
"""
|
||||||
|
return Text(f"预估时间 {_format_duration(task.time_remaining)}")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseMigrationProgressReporter(ABC):
|
||||||
|
"""数据库迁移进度上报器基类。"""
|
||||||
|
|
||||||
|
def __enter__(self) -> "BaseMigrationProgressReporter":
|
||||||
|
"""进入进度上报上下文。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseMigrationProgressReporter: 当前上报器实例。
|
||||||
|
"""
|
||||||
|
self.open()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_value, traceback) -> None:
|
||||||
|
"""退出进度上报上下文。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exc_type: 异常类型。
|
||||||
|
exc_value: 异常实例。
|
||||||
|
traceback: 异常追踪对象。
|
||||||
|
"""
|
||||||
|
del exc_type, exc_value, traceback
|
||||||
|
self.close()
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def open(self) -> None:
|
||||||
|
"""打开进度上报资源。"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def close(self) -> None:
|
||||||
|
"""关闭进度上报资源。"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
total: int,
|
||||||
|
description: str = "总迁移进度",
|
||||||
|
unit_name: str = "表",
|
||||||
|
) -> None:
|
||||||
|
"""启动一个新的迁移进度任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total: 任务总数。
|
||||||
|
description: 任务描述。
|
||||||
|
unit_name: 进度单位名称。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||||
|
"""推进当前迁移进度任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
advance: 本次推进的步数。
|
||||||
|
item_name: 当前完成的项目名称。
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class NullMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||||
|
"""不输出任何内容的空进度上报器。"""
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""关闭空进度上报器。"""
|
||||||
|
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
total: int,
|
||||||
|
description: str = "总迁移进度",
|
||||||
|
unit_name: str = "表",
|
||||||
|
) -> None:
|
||||||
|
"""启动空进度任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total: 任务总数。
|
||||||
|
description: 任务描述。
|
||||||
|
unit_name: 进度单位名称。
|
||||||
|
"""
|
||||||
|
del total, description, unit_name
|
||||||
|
|
||||||
|
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||||
|
"""推进空进度任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
advance: 本次推进的步数。
|
||||||
|
item_name: 当前完成的项目名称。
|
||||||
|
"""
|
||||||
|
del advance, item_name
|
||||||
|
|
||||||
|
|
||||||
|
class RichMigrationProgressReporter(BaseMigrationProgressReporter):
|
||||||
|
"""基于 ``rich`` 的数据库迁移进度上报器。"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
console: Optional[Console] = None,
|
||||||
|
disable: Optional[bool] = None,
|
||||||
|
refresh_per_second: int = 10,
|
||||||
|
) -> None:
|
||||||
|
"""初始化 ``rich`` 迁移进度上报器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
console: 输出使用的 ``rich`` 控制台。
|
||||||
|
disable: 是否禁用进度条;为空时根据终端能力自动判断。
|
||||||
|
refresh_per_second: 每秒刷新次数。
|
||||||
|
"""
|
||||||
|
self.console = console or Console()
|
||||||
|
self.disable = disable
|
||||||
|
self.refresh_per_second = refresh_per_second
|
||||||
|
self._progress: Optional[Progress] = None
|
||||||
|
self._task_id: Optional[TaskID] = None
|
||||||
|
|
||||||
|
def open(self) -> None:
|
||||||
|
"""打开 ``rich`` 进度条资源。"""
|
||||||
|
effective_disable = not self.console.is_terminal if self.disable is None else self.disable
|
||||||
|
self._progress = Progress(
|
||||||
|
MigrationSummaryColumn(),
|
||||||
|
BarColumn(),
|
||||||
|
MigrationSpeedColumn(),
|
||||||
|
MigrationElapsedColumn(),
|
||||||
|
MigrationRemainingColumn(),
|
||||||
|
console=self.console,
|
||||||
|
transient=False,
|
||||||
|
disable=effective_disable,
|
||||||
|
refresh_per_second=self.refresh_per_second,
|
||||||
|
expand=True,
|
||||||
|
)
|
||||||
|
self._progress.start()
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
"""关闭 ``rich`` 进度条资源。"""
|
||||||
|
if self._progress is None:
|
||||||
|
return
|
||||||
|
self._progress.stop()
|
||||||
|
self._progress = None
|
||||||
|
self._task_id = None
|
||||||
|
|
||||||
|
def start(
|
||||||
|
self,
|
||||||
|
total: int,
|
||||||
|
description: str = "总迁移进度",
|
||||||
|
unit_name: str = "表",
|
||||||
|
) -> None:
|
||||||
|
"""启动一个新的 ``rich`` 迁移进度任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
total: 任务总数。
|
||||||
|
description: 任务描述。
|
||||||
|
unit_name: 进度单位名称。
|
||||||
|
"""
|
||||||
|
if self._progress is None:
|
||||||
|
self.open()
|
||||||
|
assert self._progress is not None
|
||||||
|
effective_total = max(total, 1)
|
||||||
|
self._task_id = self._progress.add_task(
|
||||||
|
description,
|
||||||
|
total=effective_total,
|
||||||
|
display_total=total,
|
||||||
|
unit_name=unit_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None:
|
||||||
|
"""推进当前 ``rich`` 迁移进度任务。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
advance: 本次推进的步数。
|
||||||
|
item_name: 当前完成的项目名称。
|
||||||
|
"""
|
||||||
|
del item_name
|
||||||
|
if self._progress is None or self._task_id is None:
|
||||||
|
return
|
||||||
|
self._progress.update(self._task_id, advance=advance)
|
||||||
|
|
||||||
|
|
||||||
|
def create_rich_migration_progress_reporter() -> BaseMigrationProgressReporter:
|
||||||
|
"""创建默认的 ``rich`` 迁移进度上报器。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseMigrationProgressReporter: 默认迁移进度上报器实例。
|
||||||
|
"""
|
||||||
|
return RichMigrationProgressReporter()
|
||||||
98
src/common/database/migrations/registry.py
Normal file
98
src/common/database/migrations/registry.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""数据库迁移步骤注册表。"""
|
||||||
|
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from .exceptions import DatabaseMigrationConfigurationError
|
||||||
|
from .models import MigrationStep
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationRegistry:
|
||||||
|
"""数据库迁移步骤注册表。"""
|
||||||
|
|
||||||
|
def __init__(self, steps: Optional[List[MigrationStep]] = None) -> None:
|
||||||
|
"""初始化迁移步骤注册表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
steps: 初始化时要注册的迁移步骤列表。
|
||||||
|
"""
|
||||||
|
self._steps_by_from_version: Dict[int, MigrationStep] = {}
|
||||||
|
if steps:
|
||||||
|
self.register_many(steps)
|
||||||
|
|
||||||
|
def register(self, step: MigrationStep) -> None:
|
||||||
|
"""注册单个迁移步骤。
|
||||||
|
|
||||||
|
当前注册表要求每个步骤只负责相邻版本间的升级,以确保迁移链路易于审计、
|
||||||
|
易于回放,也便于后续生产问题排查。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
step: 待注册的迁移步骤定义。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationConfigurationError: 当步骤定义冲突或版本跨度不合法时抛出。
|
||||||
|
"""
|
||||||
|
if step.version_to != step.version_from + 1:
|
||||||
|
raise DatabaseMigrationConfigurationError(
|
||||||
|
"迁移步骤必须使用相邻版本号定义,例如 2 -> 3。"
|
||||||
|
)
|
||||||
|
if step.version_from in self._steps_by_from_version:
|
||||||
|
existing_step = self._steps_by_from_version[step.version_from]
|
||||||
|
raise DatabaseMigrationConfigurationError(
|
||||||
|
f"版本 {step.version_from} 已存在迁移步骤: {existing_step.name}"
|
||||||
|
)
|
||||||
|
for registered_step in self._steps_by_from_version.values():
|
||||||
|
if registered_step.version_to == step.version_to:
|
||||||
|
raise DatabaseMigrationConfigurationError(
|
||||||
|
f"目标版本 {step.version_to} 已由迁移步骤 {registered_step.name} 占用。"
|
||||||
|
)
|
||||||
|
self._steps_by_from_version[step.version_from] = step
|
||||||
|
|
||||||
|
def register_many(self, steps: List[MigrationStep]) -> None:
|
||||||
|
"""批量注册多个迁移步骤。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
steps: 待注册的迁移步骤列表。
|
||||||
|
"""
|
||||||
|
for step in steps:
|
||||||
|
self.register(step)
|
||||||
|
|
||||||
|
def get_step(self, version_from: int) -> Optional[MigrationStep]:
|
||||||
|
"""获取指定起始版本的迁移步骤。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version_from: 迁移步骤的起始版本号。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[MigrationStep]: 若存在对应步骤则返回,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
return self._steps_by_from_version.get(version_from)
|
||||||
|
|
||||||
|
def has_step(self, version_from: int) -> bool:
|
||||||
|
"""判断指定起始版本是否已注册迁移步骤。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version_from: 待检查的起始版本号。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若已注册对应步骤则返回 ``True``。
|
||||||
|
"""
|
||||||
|
return version_from in self._steps_by_from_version
|
||||||
|
|
||||||
|
def latest_version(self) -> int:
|
||||||
|
"""返回当前注册表支持到的最新 schema 版本。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 若注册表为空则返回 ``0``,否则返回最大目标版本号。
|
||||||
|
"""
|
||||||
|
if not self._steps_by_from_version:
|
||||||
|
return 0
|
||||||
|
return max(step.version_to for step in self._steps_by_from_version.values())
|
||||||
|
|
||||||
|
def list_steps(self) -> List[MigrationStep]:
|
||||||
|
"""按起始版本顺序返回全部迁移步骤。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[MigrationStep]: 已注册迁移步骤列表。
|
||||||
|
"""
|
||||||
|
ordered_versions = sorted(self._steps_by_from_version)
|
||||||
|
return [self._steps_by_from_version[version] for version in ordered_versions]
|
||||||
135
src/common/database/migrations/resolver.py
Normal file
135
src/common/database/migrations/resolver.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""数据库版本解析器。"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
|
||||||
|
from .exceptions import DatabaseMigrationVersionError, UnrecognizedDatabaseSchemaError
|
||||||
|
from .models import DatabaseSchemaSnapshot, ResolvedSchemaVersion, SchemaVersionSource
|
||||||
|
from .schema import SQLiteSchemaInspector
|
||||||
|
from .version_store import SQLiteUserVersionStore
|
||||||
|
|
||||||
|
|
||||||
|
class BaseSchemaVersionDetector(ABC):
|
||||||
|
"""未标记版本数据库的结构探测器基类。"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""返回当前探测器名称。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 当前探测器名称。
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
|
||||||
|
"""根据数据库结构快照推断版本号。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot: 当前数据库结构快照。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[int]: 若识别成功则返回版本号,否则返回 ``None``。
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaVersionResolver:
|
||||||
|
"""数据库版本解析器。"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
version_store: Optional[SQLiteUserVersionStore] = None,
|
||||||
|
schema_inspector: Optional[SQLiteSchemaInspector] = None,
|
||||||
|
detectors: Optional[List[BaseSchemaVersionDetector]] = None,
|
||||||
|
) -> None:
|
||||||
|
"""初始化数据库版本解析器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version_store: 版本存储器;未提供时将使用默认实现。
|
||||||
|
schema_inspector: 结构探测器;未提供时将使用默认实现。
|
||||||
|
detectors: 未标记版本数据库的探测器列表。
|
||||||
|
"""
|
||||||
|
self.version_store = version_store or SQLiteUserVersionStore()
|
||||||
|
self.schema_inspector = schema_inspector or SQLiteSchemaInspector()
|
||||||
|
self.detectors: List[BaseSchemaVersionDetector] = list(detectors or [])
|
||||||
|
|
||||||
|
def add_detector(self, detector: BaseSchemaVersionDetector) -> None:
|
||||||
|
"""注册一个未标记版本数据库探测器。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
detector: 待注册的探测器实例。
|
||||||
|
"""
|
||||||
|
self.detectors.append(detector)
|
||||||
|
|
||||||
|
def list_detectors(self) -> List[BaseSchemaVersionDetector]:
|
||||||
|
"""返回当前已注册的全部探测器。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[BaseSchemaVersionDetector]: 已注册探测器列表副本。
|
||||||
|
"""
|
||||||
|
return list(self.detectors)
|
||||||
|
|
||||||
|
def resolve(self, connection: Connection) -> ResolvedSchemaVersion:
|
||||||
|
"""解析当前数据库的 schema 版本信息。
|
||||||
|
|
||||||
|
解析顺序如下:
|
||||||
|
1. 优先读取 ``PRAGMA user_version``。
|
||||||
|
2. 若其值为 0,则对数据库结构做快照。
|
||||||
|
3. 若数据库为空,则返回空库版本。
|
||||||
|
4. 若数据库非空,则交给探测器链进行识别。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResolvedSchemaVersion: 解析后的数据库版本信息。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。
|
||||||
|
UnrecognizedDatabaseSchemaError: 当数据库非空但无法识别版本时抛出。
|
||||||
|
"""
|
||||||
|
recorded_version = self.version_store.read_version(connection)
|
||||||
|
if recorded_version > 0:
|
||||||
|
return ResolvedSchemaVersion(version=recorded_version, source=SchemaVersionSource.PRAGMA)
|
||||||
|
|
||||||
|
snapshot = self.schema_inspector.inspect(connection)
|
||||||
|
if snapshot.is_empty():
|
||||||
|
return ResolvedSchemaVersion(
|
||||||
|
version=0,
|
||||||
|
source=SchemaVersionSource.EMPTY_DATABASE,
|
||||||
|
snapshot=snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._detect_unversioned_database(snapshot)
|
||||||
|
|
||||||
|
def _detect_unversioned_database(self, snapshot: DatabaseSchemaSnapshot) -> ResolvedSchemaVersion:
|
||||||
|
"""识别未标记版本的历史数据库。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
snapshot: 当前数据库结构快照。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ResolvedSchemaVersion: 探测器识别出的版本信息。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationVersionError: 当探测器返回非法版本号时抛出。
|
||||||
|
UnrecognizedDatabaseSchemaError: 当全部探测器都无法识别结构时抛出。
|
||||||
|
"""
|
||||||
|
for detector in self.detectors:
|
||||||
|
detected_version = detector.detect_version(snapshot)
|
||||||
|
if detected_version is None:
|
||||||
|
continue
|
||||||
|
if detected_version < 0:
|
||||||
|
raise DatabaseMigrationVersionError(
|
||||||
|
f"探测器 {detector.name!r} 返回了非法版本号: {detected_version}"
|
||||||
|
)
|
||||||
|
return ResolvedSchemaVersion(
|
||||||
|
version=detected_version,
|
||||||
|
source=SchemaVersionSource.DETECTOR,
|
||||||
|
detector_name=detector.name,
|
||||||
|
snapshot=snapshot,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise UnrecognizedDatabaseSchemaError("当前数据库未记录版本号,且现有探测器无法识别其结构。")
|
||||||
98
src/common/database/migrations/schema.py
Normal file
98
src/common/database/migrations/schema.py
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
"""SQLite 数据库结构探测工具。"""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
|
||||||
|
from .models import ColumnSchema, DatabaseSchemaSnapshot, TableSchema
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteSchemaInspector:
|
||||||
|
"""SQLite 数据库结构探测器。"""
|
||||||
|
|
||||||
|
def inspect(self, connection: Connection) -> DatabaseSchemaSnapshot:
|
||||||
|
"""提取数据库中的全部用户表结构快照。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DatabaseSchemaSnapshot: 当前数据库结构快照。
|
||||||
|
"""
|
||||||
|
tables: Dict[str, TableSchema] = {}
|
||||||
|
for table_name in self.list_user_tables(connection):
|
||||||
|
table_schema = self.get_table_schema(connection, table_name)
|
||||||
|
tables[table_name] = table_schema
|
||||||
|
return DatabaseSchemaSnapshot(tables=tables)
|
||||||
|
|
||||||
|
def list_user_tables(self, connection: Connection) -> List[str]:
|
||||||
|
"""列出数据库中的全部用户表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[str]: 按字母顺序排列的用户表名列表。
|
||||||
|
"""
|
||||||
|
statement = text(
|
||||||
|
"""
|
||||||
|
SELECT name
|
||||||
|
FROM sqlite_master
|
||||||
|
WHERE type = 'table'
|
||||||
|
AND name NOT LIKE 'sqlite_%'
|
||||||
|
ORDER BY name
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
rows = connection.execute(statement).fetchall()
|
||||||
|
return [str(row[0]) for row in rows]
|
||||||
|
|
||||||
|
def get_table_schema(self, connection: Connection, table_name: str) -> TableSchema:
|
||||||
|
"""获取指定表的结构信息。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
table_name: 待读取结构的表名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TableSchema: 指定表的结构快照。
|
||||||
|
"""
|
||||||
|
quoted_table_name = self._quote_identifier(table_name)
|
||||||
|
rows = connection.exec_driver_sql(f"PRAGMA table_info({quoted_table_name})").mappings().all()
|
||||||
|
|
||||||
|
columns: Dict[str, ColumnSchema] = {}
|
||||||
|
for row in rows:
|
||||||
|
column_schema = ColumnSchema(
|
||||||
|
name=str(row["name"]),
|
||||||
|
declared_type=str(row["type"] or ""),
|
||||||
|
default_value=None if row["dflt_value"] is None else str(row["dflt_value"]),
|
||||||
|
is_not_null=bool(row["notnull"]),
|
||||||
|
primary_key_position=int(row["pk"]),
|
||||||
|
)
|
||||||
|
columns[column_schema.name] = column_schema
|
||||||
|
|
||||||
|
return TableSchema(name=table_name, columns=columns)
|
||||||
|
|
||||||
|
def table_exists(self, connection: Connection, table_name: str) -> bool:
|
||||||
|
"""判断数据库中是否存在指定表。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
table_name: 待检查的表名。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 若表存在则返回 ``True``。
|
||||||
|
"""
|
||||||
|
return table_name in self.list_user_tables(connection)
|
||||||
|
|
||||||
|
def _quote_identifier(self, identifier: str) -> str:
|
||||||
|
"""为 SQLite 标识符添加安全引号。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
identifier: 待引用的 SQLite 标识符。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 可直接拼接到 PRAGMA 语句中的安全标识符。
|
||||||
|
"""
|
||||||
|
escaped_identifier = identifier.replace('"', '""')
|
||||||
|
return f'"{escaped_identifier}"'
|
||||||
57
src/common/database/migrations/version_store.py
Normal file
57
src/common/database/migrations/version_store.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""SQLite 数据库版本存储实现。"""
|
||||||
|
|
||||||
|
from sqlalchemy.engine import Connection
|
||||||
|
|
||||||
|
from .exceptions import DatabaseMigrationVersionError
|
||||||
|
|
||||||
|
|
||||||
|
class SQLiteUserVersionStore:
|
||||||
|
"""基于 ``PRAGMA user_version`` 的 SQLite 版本存储器。"""
|
||||||
|
|
||||||
|
def read_version(self, connection: Connection) -> int:
|
||||||
|
"""读取当前数据库的 schema 版本号。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 数据库记录的 schema 版本号。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationVersionError: 当读取结果异常或版本号非法时抛出。
|
||||||
|
"""
|
||||||
|
row = connection.exec_driver_sql("PRAGMA user_version").first()
|
||||||
|
if row is None or len(row) == 0:
|
||||||
|
raise DatabaseMigrationVersionError("读取 SQLite user_version 失败,返回结果为空。")
|
||||||
|
|
||||||
|
version = row[0]
|
||||||
|
if not isinstance(version, int):
|
||||||
|
raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不是整数: {version!r}")
|
||||||
|
if version < 0:
|
||||||
|
raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不能为负数: {version}")
|
||||||
|
return version
|
||||||
|
|
||||||
|
def write_version(self, connection: Connection, version: int) -> None:
|
||||||
|
"""写入新的 schema 版本号。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connection: 当前数据库连接。
|
||||||
|
version: 待写入的 schema 版本号。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationVersionError: 当版本号非法时抛出。
|
||||||
|
"""
|
||||||
|
self._validate_version(version)
|
||||||
|
connection.exec_driver_sql(f"PRAGMA user_version = {version}")
|
||||||
|
|
||||||
|
def _validate_version(self, version: int) -> None:
|
||||||
|
"""校验版本号是否合法。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version: 待校验的版本号。
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseMigrationVersionError: 当版本号非法时抛出。
|
||||||
|
"""
|
||||||
|
if version < 0:
|
||||||
|
raise DatabaseMigrationVersionError(f"SQLite user_version 不能小于 0: {version}")
|
||||||
Reference in New Issue
Block a user