"""Tests for the database migration infrastructure."""

import json
from pathlib import Path
from typing import List, Optional, Tuple

import msgpack
import pytest
from sqlalchemy import text
from sqlalchemy.engine import Connection, Engine
from sqlmodel import SQLModel, create_engine

from src.common.database import database as database_module
from src.common.database.migrations import (
    BaseSchemaVersionDetector,
    BaseMigrationProgressReporter,
    DatabaseSchemaSnapshot,
    DatabaseMigrationBootstrapper,
    DatabaseMigrationState,
    DatabaseMigrationManager,
    EMPTY_SCHEMA_VERSION,
    LATEST_SCHEMA_VERSION,
    LEGACY_V1_SCHEMA_VERSION,
    MigrationExecutionContext,
    MigrationPlan,
    MigrationRegistry,
    MigrationStep,
    ResolvedSchemaVersion,
    SchemaVersionResolver,
    SchemaVersionSource,
    SQLiteSchemaInspector,
    SQLiteUserVersionStore,
    build_default_migration_registry,
    build_default_schema_version_resolver,
    create_database_migration_bootstrapper,
)


class FixedVersionDetector(BaseSchemaVersionDetector):
    """Test detector that reports a fixed schema version."""

    @property
    def name(self) -> str:
        """Return the test detector's name.

        Returns:
            str: Detector name.
        """
        return "fixed_version_detector"

    def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
        """Return a fixed version depending on whether the marker table exists.

        Args:
            snapshot: Snapshot of the current database schema.

        Returns:
            Optional[int]: ``2`` if the ``legacy_records`` table exists, otherwise ``None``.
        """
        if snapshot.has_table("legacy_records"):
            return 2
        return None


class FakeMigrationProgressReporter(BaseMigrationProgressReporter):
    """Test progress reporter that records every call as an event tuple."""

    def __init__(self) -> None:
        """Initialize the reporter with an empty event log."""
        # Each event is (event_name, first_int, second_int, label).
        self.events: List[Tuple[str, Optional[int], Optional[int], Optional[str]]] = []

    def open(self) -> None:
        """Record an ``open`` event."""
        self.events.append(("open", None, None, None))

    def close(self) -> None:
        """Record a ``close`` event."""
        self.events.append(("close", None, None, None))

    def start(
        self,
        total_records: int,
        total_tables: int,
        description: str = "总迁移进度",
        table_unit_name: str = "表",
        record_unit_name: str = "记录",
    ) -> None:
        """Record a ``start`` event.

        Args:
            total_records: Total number of records in the task.
            total_tables: Total number of tables in the task.
            description: Task description.
            table_unit_name: Unit name for table-level progress (ignored).
            record_unit_name: Unit name for record-level progress (ignored).
        """
        del table_unit_name, record_unit_name
        self.events.append(("start", total_records, total_tables, description))

    def advance(
        self,
        records: int = 0,
        completed_tables: int = 0,
        item_name: Optional[str] = None,
    ) -> None:
        """Record an ``advance`` event.

        Args:
            records: Number of records advanced.
            completed_tables: Number of tables completed.
            item_name: Name of the item that was just completed, if any.
        """
        self.events.append(("advance", records, completed_tables, item_name))


def _create_sqlite_engine(database_file: Path) -> Engine:
    """Create a SQLite engine for tests.

    Args:
        database_file: Path of the test database file.

    Returns:
        Engine: SQLite engine instance.
    """
    return create_engine(
        f"sqlite:///{database_file}",
        echo=False,
        connect_args={"check_same_thread": False},
    )


def _create_current_schema(connection: Connection) -> None:
    """Create the latest database schema.

    Args:
        connection: Current database connection.
    """
    # Imported only for its side effect of registering the table models on
    # ``SQLModel.metadata`` before ``create_all`` runs.
    import src.common.database.database_model  # noqa: F401

    SQLModel.metadata.create_all(connection)


def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None:
    """Create the legacy ``0.x`` database schema and populate it with sample rows.

    Args:
        connection: Current database connection.
    """
    connection.execute(
        text(
            """
            CREATE TABLE chat_streams (
                id INTEGER PRIMARY KEY,
                stream_id TEXT NOT NULL,
                create_time REAL NOT NULL,
                last_active_time REAL NOT NULL,
                platform TEXT NOT NULL,
                user_id TEXT,
                group_id TEXT,
                group_name TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            CREATE TABLE messages (
                id INTEGER PRIMARY KEY,
                message_id TEXT NOT NULL,
                time REAL NOT NULL,
                chat_id TEXT NOT NULL,
                chat_info_platform TEXT,
                user_id TEXT,
                user_nickname TEXT,
                chat_info_group_id TEXT,
                chat_info_group_name TEXT,
                is_mentioned INTEGER,
                is_at INTEGER,
                processed_plain_text TEXT,
                display_message TEXT,
                is_emoji INTEGER,
                is_picid INTEGER,
                is_command INTEGER,
                is_notify INTEGER,
                additional_config TEXT,
                priority_mode TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            CREATE TABLE action_records (
                id INTEGER PRIMARY KEY,
                action_id TEXT NOT NULL,
                time REAL NOT NULL,
                action_reasoning TEXT,
                action_name TEXT NOT NULL,
                action_data TEXT,
                action_prompt_display TEXT,
                chat_id TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            CREATE TABLE expression (
                id INTEGER PRIMARY KEY,
                situation TEXT NOT NULL,
                style TEXT NOT NULL,
                content_list TEXT,
                count INTEGER,
                last_active_time REAL NOT NULL,
                chat_id TEXT,
                create_date REAL,
                checked INTEGER,
                rejected INTEGER,
                modified_by TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            CREATE TABLE jargon (
                id INTEGER PRIMARY KEY,
                content TEXT NOT NULL,
                raw_content TEXT,
                meaning TEXT,
                chat_id TEXT,
                is_global INTEGER,
                count INTEGER,
                is_jargon INTEGER,
                last_inference_count INTEGER,
                is_complete INTEGER,
                inference_with_context TEXT,
                inference_content_only TEXT
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO chat_streams (
                id, stream_id, create_time, last_active_time, platform, user_id, group_id, group_name
            ) VALUES (
                1, 'session-1', 1710000000.0, 1710000300.0, 'qq', 'user-1', 'group-1', '测试群'
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO messages (
                id, message_id, time, chat_id, chat_info_platform, user_id, user_nickname,
                chat_info_group_id, chat_info_group_name, is_mentioned, is_at, processed_plain_text,
                display_message, is_emoji, is_picid, is_command, is_notify, additional_config,
                priority_mode
            ) VALUES (
                1, 'msg-1', 1710000010.0, 'session-1', 'qq', 'user-1', '测试用户',
                'group-1', '测试群', 1, 0, '你好',
                '你好呀', 0, 1, 0, 1, '{"source":"legacy"}',
                'high'
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO action_records (
                id, action_id, time, action_reasoning, action_name, action_data,
                action_prompt_display, chat_id
            ) VALUES (
                1, 'action-1', 1710000020.0, '需要调用工具', 'search', '{"query":"MaiBot"}',
                '执行搜索', 'session-1'
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO expression (
                id, situation, style, content_list, count, last_active_time, chat_id,
                create_date, checked, rejected, modified_by
            ) VALUES (
                1, '打招呼', '可爱', '["你好呀","早上好"]', 3, 1710000030.0, 'session-1',
                1710000040.0, 1, 0, 'ai'
            )
            """
        )
    )
    connection.execute(
        text(
            """
            INSERT INTO jargon (
                id, content, raw_content, meaning, chat_id, is_global, count, is_jargon,
                last_inference_count, is_complete, inference_with_context, inference_content_only
            ) VALUES (
                1, '上分', '["上分"]', '提高排名', 'session-1', 0, 5, 1,
                2, 1, '{"guess":"context"}', '{"guess":"content"}'
            )
            """
        )
    )


def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None:
    """The store should read and write SQLite ``user_version``."""
    engine = _create_sqlite_engine(tmp_path / "version_store.db")
    version_store = SQLiteUserVersionStore()

    with engine.begin() as connection:
        assert version_store.read_version(connection) == 0
        version_store.write_version(connection, 7)

    with engine.connect() as connection:
        assert version_store.read_version(connection) == 7


def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None:
    """The inspector should extract tables and columns from a SQLite database."""
    engine = _create_sqlite_engine(tmp_path / "schema_inspector.db")
    inspector = SQLiteSchemaInspector()

    with engine.begin() as connection:
        connection.execute(
            text(
                """
                CREATE TABLE legacy_records (
                    id INTEGER PRIMARY KEY,
                    payload TEXT NOT NULL,
                    created_at TEXT
                )
                """
            )
        )

    with engine.connect() as connection:
        snapshot = inspector.inspect(connection)

    assert snapshot.has_table("legacy_records")
    assert snapshot.has_column("legacy_records", "payload")
    assert not snapshot.has_column("legacy_records", "missing_column")
    table_schema = snapshot.get_table("legacy_records")
    assert table_schema is not None
    assert table_schema.column_names() == ["created_at", "id", "payload"]


def test_resolver_can_identify_empty_database(tmp_path: Path) -> None:
    """An empty database should resolve to version ``0``."""
    engine = _create_sqlite_engine(tmp_path / "empty_resolver.db")
    resolver = SchemaVersionResolver()

    with engine.connect() as connection:
        resolved_version = resolver.resolve(connection)

    assert resolved_version.version == 0
    assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE
    assert resolved_version.snapshot is not None
    assert resolved_version.snapshot.is_empty()


def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None:
    """A legacy database without ``user_version`` should be identified via a detector."""
    engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db")
    resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()])

    with engine.begin() as connection:
        connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)"))

    with engine.connect() as connection:
        resolved_version = resolver.resolve(connection)

    assert resolved_version.version == 2
    assert resolved_version.source == SchemaVersionSource.DETECTOR
    assert resolved_version.detector_name == "fixed_version_detector"


def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None:
    """The manager should run registered steps in order and update the version number."""
    engine = _create_sqlite_engine(tmp_path / "manager.db")
    executed_steps: List[str] = []

    def migrate_0_to_1(context: MigrationExecutionContext) -> None:
        """Test migration step ``0 -> 1``.

        Args:
            context: Execution context of the current migration step.
        """
        executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1")
        context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)"))

    def migrate_1_to_2(context: MigrationExecutionContext) -> None:
        """Test migration step ``1 -> 2``.

        Args:
            context: Execution context of the current migration step.
        """
        executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2")
        context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT"))

    registry = MigrationRegistry(
        steps=[
            MigrationStep(
                version_from=0,
                version_to=1,
                name="create_sample_records",
                description="创建示例表。",
                handler=migrate_0_to_1,
            ),
            MigrationStep(
                version_from=1,
                version_to=2,
                name="add_sample_email",
                description="为示例表增加邮箱字段。",
                handler=migrate_1_to_2,
            ),
        ]
    )
    manager = DatabaseMigrationManager(engine=engine, registry=registry)

    migration_plan = manager.migrate()

    assert migration_plan.step_count() == 2
    # ``target_version`` in the context is the overall plan target, not the step's ``version_to``.
    assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"]

    with engine.connect() as connection:
        version_store = SQLiteUserVersionStore()
        snapshot = SQLiteSchemaInspector().inspect(connection)
        recorded_version = version_store.read_version(connection)

    assert recorded_version == 2
    assert snapshot.has_table("sample_records")
    assert snapshot.has_column("sample_records", "email")


def test_manager_can_report_step_progress(tmp_path: Path) -> None:
    """The manager should let a step report progress through its context."""
    engine = _create_sqlite_engine(tmp_path / "manager_progress.db")
    reporter_instances: List[FakeMigrationProgressReporter] = []

    def _build_reporter() -> BaseMigrationProgressReporter:
        """Build a test progress reporter and remember it.

        Returns:
            BaseMigrationProgressReporter: Test progress reporter instance.
        """
        reporter = FakeMigrationProgressReporter()
        reporter_instances.append(reporter)
        return reporter

    def migrate_1_to_2(context: MigrationExecutionContext) -> None:
        """Report progress from migration step ``1 -> 2``.

        Args:
            context: Execution context of the current migration step.
        """
        context.start_progress(total_tables=3, total_records=30, description="总迁移进度")
        context.advance_progress(records=10, completed_tables=1, item_name="chat_sessions")
        context.advance_progress(records=10, completed_tables=1, item_name="mai_messages")
        context.advance_progress(records=10, completed_tables=1, item_name="tool_records")
        context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)"))

    with engine.begin() as connection:
        SQLiteUserVersionStore().write_version(connection, 1)

    registry = MigrationRegistry(
        steps=[
            MigrationStep(
                version_from=1,
                version_to=2,
                name="progress_step",
                description="测试进度上报。",
                handler=migrate_1_to_2,
            )
        ]
    )
    manager = DatabaseMigrationManager(
        engine=engine,
        registry=registry,
        progress_reporter_factory=_build_reporter,
    )

    migration_plan = manager.migrate()

    assert migration_plan.step_count() == 1
    assert len(reporter_instances) == 1
    assert reporter_instances[0].events == [
        ("open", None, None, None),
        ("start", 30, 3, "总迁移进度"),
        ("advance", 10, 1, "chat_sessions"),
        ("advance", 10, 1, "mai_messages"),
        ("advance", 10, 1, "tool_records"),
        ("close", None, None, None),
    ]


def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None:
    """The default resolver should identify a latest-schema database lacking a version number."""
    engine = _create_sqlite_engine(tmp_path / "latest_resolver.db")
    resolver = build_default_schema_version_resolver()

    with engine.begin() as connection:
        _create_current_schema(connection)

    with engine.connect() as connection:
        resolved_version = resolver.resolve(connection)

    assert resolved_version.version == LATEST_SCHEMA_VERSION
    assert resolved_version.source == SchemaVersionSource.DETECTOR
    assert resolved_version.detector_name == "latest_schema_detector"


def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None:
    """The default resolver should identify an unversioned legacy ``0.x`` database."""
    engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db")
    resolver = build_default_schema_version_resolver()

    with engine.begin() as connection:
        _create_legacy_v1_schema_with_sample_data(connection)

    with engine.connect() as connection:
        resolved_version = resolver.resolve(connection)

    assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION
    assert resolved_version.source == SchemaVersionSource.DETECTOR
    assert resolved_version.detector_name == "legacy_v1_schema_detector"


def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None:
    """A latest-schema database without a version number should just get ``user_version`` written."""
    engine = _create_sqlite_engine(tmp_path / "latest_finalize.db")
    bootstrapper = create_database_migration_bootstrapper(engine)

    with engine.begin() as connection:
        _create_current_schema(connection)

    migration_state = bootstrapper.prepare_database()
    bootstrapper.finalize_database(migration_state)

    assert not migration_state.requires_migration()
    assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
    assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR

    with engine.connect() as connection:
        recorded_version = SQLiteUserVersionStore().read_version(connection)

    assert recorded_version == LATEST_SCHEMA_VERSION


def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None:
    """An empty database should get the latest ``user_version`` written after tables are created."""
    engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db")
    bootstrapper = create_database_migration_bootstrapper(engine)

    migration_state = bootstrapper.prepare_database()

    assert not migration_state.requires_migration()
    assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION
    assert migration_state.target_version == LATEST_SCHEMA_VERSION

    with engine.begin() as connection:
        _create_current_schema(connection)

    # Finalize only after the schema transaction has committed and released its connection.
    bootstrapper.finalize_database(migration_state)

    with engine.connect() as connection:
        recorded_version = SQLiteUserVersionStore().read_version(connection)

    assert recorded_version == LATEST_SCHEMA_VERSION


def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None:
    """The bootstrapper should run registered steps on a database with a recorded old version."""
    engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db")
    execution_marks: List[str] = []

    def migrate_1_to_2(context: MigrationExecutionContext) -> None:
        """Bootstrapper test migration step ``1 -> 2``.

        Args:
            context: Execution context of the current migration step.
        """
        execution_marks.append(f"step={context.step_name},index={context.step_index}")
        context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT"))

    with engine.begin() as connection:
        connection.execute(
            text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")
        )
        SQLiteUserVersionStore().write_version(connection, 1)

    registry = MigrationRegistry(
        steps=[
            MigrationStep(
                version_from=1,
                version_to=2,
                name="bootstrap_add_email",
                description="为桥接器测试表增加邮箱字段。",
                handler=migrate_1_to_2,
            )
        ]
    )
    bootstrapper = DatabaseMigrationBootstrapper(
        manager=DatabaseMigrationManager(engine=engine, registry=registry),
        latest_schema_version=2,
    )

    migration_state = bootstrapper.prepare_database()

    assert migration_state.resolved_version.version == 2
    assert migration_state.target_version == 2
    assert execution_marks == ["step=bootstrap_add_email,index=1"]

    with engine.connect() as connection:
        snapshot = SQLiteSchemaInspector().inspect(connection)
        recorded_version = SQLiteUserVersionStore().read_version(connection)

    assert recorded_version == 2
    assert snapshot.has_table("bootstrap_records")
    assert snapshot.has_column("bootstrap_records", "email")


def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None:
    """The default bootstrapper should migrate a legacy ``0.x`` database to the latest schema."""
    engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db")
    bootstrapper = create_database_migration_bootstrapper(engine)

    with engine.begin() as connection:
        _create_legacy_v1_schema_with_sample_data(connection)

    migration_state = bootstrapper.prepare_database()
    bootstrapper.finalize_database(migration_state)

    assert not migration_state.requires_migration()
    assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION
    assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA

    with engine.connect() as connection:
        recorded_version = SQLiteUserVersionStore().read_version(connection)
        snapshot = SQLiteSchemaInspector().inspect(connection)
        message_row = connection.execute(
            text(
                """
                SELECT session_id, processed_plain_text, additional_config, raw_content
                FROM mai_messages
                WHERE message_id = 'msg-1'
                """
            )
        ).mappings().one()
        tool_row = connection.execute(
            text(
                """
                SELECT session_id, tool_name, tool_display_prompt
                FROM tool_records
                WHERE tool_id = 'action-1'
                """
            )
        ).mappings().one()
        expression_row = connection.execute(
            text(
                """
                SELECT session_id, content_list, modified_by
                FROM expressions
                WHERE id = 1
                """
            )
        ).mappings().one()
        jargon_row = connection.execute(
            text(
                """
                SELECT session_id_dict, raw_content, inference_with_content_only
                FROM jargons
                WHERE id = 1
                """
            )
        ).mappings().one()

    assert recorded_version == LATEST_SCHEMA_VERSION
    assert snapshot.has_table("__legacy_v1_messages")
    assert snapshot.has_table("chat_sessions")
    assert snapshot.has_table("mai_messages")
    assert snapshot.has_table("tool_records")
    assert not snapshot.has_table("action_records")
    assert not snapshot.has_column("mai_messages", "display_message")

    unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False)
    additional_config = json.loads(message_row["additional_config"])
    expression_content_list = json.loads(expression_row["content_list"])
    jargon_session_id_dict = json.loads(jargon_row["session_id_dict"])
    jargon_raw_content = json.loads(jargon_row["raw_content"])

    assert message_row["session_id"] == "session-1"
    assert message_row["processed_plain_text"] == "你好"
    assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}]
    assert additional_config == {"priority_mode": "high", "source": "legacy"}
    assert tool_row["session_id"] == "session-1"
    assert tool_row["tool_name"] == "search"
    assert tool_row["tool_display_prompt"] == "执行搜索"
    assert expression_row["session_id"] == "session-1"
    assert expression_row["modified_by"] == "AI"
    assert expression_content_list == ["你好呀", "早上好"]
    assert jargon_session_id_dict == {"session-1": 5}
    assert jargon_raw_content == ["上分"]
    assert jargon_row["inference_with_content_only"] == '{"guess":"content"}'


def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None:
    """The legacy migration step should advance overall progress per target table."""
    engine = _create_sqlite_engine(tmp_path / "legacy_progress.db")
    reporter_instances: List[FakeMigrationProgressReporter] = []

    def _build_reporter() -> BaseMigrationProgressReporter:
        """Build a test progress reporter and remember it.

        Returns:
            BaseMigrationProgressReporter: Test progress reporter instance.
        """
        reporter = FakeMigrationProgressReporter()
        reporter_instances.append(reporter)
        return reporter

    with engine.begin() as connection:
        _create_legacy_v1_schema_with_sample_data(connection)

    manager = DatabaseMigrationManager(
        engine=engine,
        registry=build_default_migration_registry(),
        resolver=build_default_schema_version_resolver(),
        progress_reporter_factory=_build_reporter,
    )

    migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION)

    assert migration_plan.step_count() == 3
    assert len(reporter_instances) == 3

    reporter_events = reporter_instances[0].events
    assert reporter_events[0] == ("open", None, None, None)
    assert reporter_events[1] == ("start", 6, 12, "总迁移进度")
    assert reporter_events[-1] == ("close", None, None, None)
    assert reporter_events.count(("advance", 1, 0, None)) == 6
    assert reporter_events.count(("advance", 0, 1, "chat_sessions")) == 1
    assert reporter_events.count(("advance", 0, 1, "thinking_questions")) == 1
    assert len([event for event in reporter_events if event[0] == "advance"]) == 18


def test_initialize_database_calls_bootstrapper_before_create_all(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
) -> None:
    """``initialize_database`` should prepare migrations, then create tables, then finalize."""
    call_order: List[str] = []

    def _fake_prepare_database() -> DatabaseMigrationState:
        """Record the call and return a migration state with no steps.

        Returns:
            DatabaseMigrationState: Test state that contains no migration steps.
        """
        call_order.append("prepare_database")
        return DatabaseMigrationState(
            resolved_version=ResolvedSchemaVersion(
                version=EMPTY_SCHEMA_VERSION,
                source=SchemaVersionSource.EMPTY_DATABASE,
            ),
            target_version=LATEST_SCHEMA_VERSION,
            plan=MigrationPlan(
                current_version=EMPTY_SCHEMA_VERSION,
                target_version=LATEST_SCHEMA_VERSION,
                steps=[],
            ),
        )

    def _fake_create_all(bind: object, *args: object, **kwargs: object) -> None:
        """Record the table-creation call.

        Accepts the same extra arguments as ``MetaData.create_all`` (``tables``,
        ``checkfirst``) so the fake does not raise ``TypeError`` if the caller
        passes them.

        Args:
            bind: Database bind object passed in by the caller.
            *args: Ignored extra positional arguments.
            **kwargs: Ignored extra keyword arguments.
        """
        del bind, args, kwargs
        call_order.append("create_all")

    def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None:
        """Record the migration finalization call.

        Args:
            migration_state: Current database migration state.
        """
        del migration_state
        call_order.append("finalize_database")

    monkeypatch.setattr(database_module, "_db_initialized", False)
    monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data")
    monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database)
    monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database)
    monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all)

    database_module.initialize_database()

    assert call_order == [
        "prepare_database",
        "create_all",
        "finalize_database",
    ]