From c2c992ff01c39b778c8fd0639bd40229a5d72216 Mon Sep 17 00:00:00 2001 From: DrSmoothl <1787882683@qq.com> Date: Tue, 31 Mar 2026 09:16:25 +0800 Subject: [PATCH] feat(database-migrations): implement database migration manager and related components - Add DatabaseMigrationManager for orchestrating database migrations, including planning and executing migration steps. - Introduce models for migration state, execution context, and migration steps. - Implement MigrationPlanner to generate migration plans based on current and target versions. - Create MigrationRegistry for registering and managing migration steps. - Develop SchemaVersionResolver to determine the current database schema version. - Add SQLiteSchemaInspector for inspecting SQLite database structures. - Implement progress reporting tools using rich for visualizing migration progress. - Introduce SQLiteUserVersionStore for managing schema version storage in SQLite. --- .../test_database_migration_foundation.py | 912 +++++++++++ src/common/database/database.py | 33 +- src/common/database/migrations/__init__.py | 79 + src/common/database/migrations/bootstrap.py | 171 ++ src/common/database/migrations/builtin.py | 159 ++ src/common/database/migrations/exceptions.py | 33 + .../database/migrations/legacy_v1_to_v2.py | 1384 +++++++++++++++++ src/common/database/migrations/manager.py | 205 +++ src/common/database/migrations/models.py | 285 ++++ src/common/database/migrations/planner.py | 108 ++ src/common/database/migrations/progress.py | 272 ++++ src/common/database/migrations/registry.py | 98 ++ src/common/database/migrations/resolver.py | 135 ++ src/common/database/migrations/schema.py | 98 ++ .../database/migrations/version_store.py | 57 + 15 files changed, 4025 insertions(+), 4 deletions(-) create mode 100644 pytests/common_test/test_database_migration_foundation.py create mode 100644 src/common/database/migrations/__init__.py create mode 100644 src/common/database/migrations/bootstrap.py create mode 100644 
src/common/database/migrations/builtin.py create mode 100644 src/common/database/migrations/exceptions.py create mode 100644 src/common/database/migrations/legacy_v1_to_v2.py create mode 100644 src/common/database/migrations/manager.py create mode 100644 src/common/database/migrations/models.py create mode 100644 src/common/database/migrations/planner.py create mode 100644 src/common/database/migrations/progress.py create mode 100644 src/common/database/migrations/registry.py create mode 100644 src/common/database/migrations/resolver.py create mode 100644 src/common/database/migrations/schema.py create mode 100644 src/common/database/migrations/version_store.py diff --git a/pytests/common_test/test_database_migration_foundation.py b/pytests/common_test/test_database_migration_foundation.py new file mode 100644 index 00000000..9c930744 --- /dev/null +++ b/pytests/common_test/test_database_migration_foundation.py @@ -0,0 +1,912 @@ +"""数据库迁移基础设施测试。""" + +from pathlib import Path +from typing import List, Optional, Tuple + +from sqlalchemy import text +from sqlalchemy.engine import Connection, Engine +from sqlmodel import SQLModel, create_engine + +import json +import msgpack +import pytest + +from src.common.database import database as database_module +from src.common.database.migrations import ( + BaseSchemaVersionDetector, + BaseMigrationProgressReporter, + DatabaseSchemaSnapshot, + DatabaseMigrationBootstrapper, + DatabaseMigrationState, + DatabaseMigrationManager, + EMPTY_SCHEMA_VERSION, + LATEST_SCHEMA_VERSION, + LEGACY_V1_SCHEMA_VERSION, + MigrationExecutionContext, + MigrationPlan, + MigrationRegistry, + MigrationStep, + ResolvedSchemaVersion, + SchemaVersionResolver, + SchemaVersionSource, + SQLiteSchemaInspector, + SQLiteUserVersionStore, + build_default_migration_registry, + build_default_schema_version_resolver, + create_database_migration_bootstrapper, +) + + +class FixedVersionDetector(BaseSchemaVersionDetector): + """测试用固定版本探测器。""" + + @property + def 
name(self) -> str: + """返回测试探测器名称。 + + Returns: + str: 探测器名称。 + """ + return "fixed_version_detector" + + def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]: + """根据测试表是否存在返回固定版本。 + + Args: + snapshot: 当前数据库结构快照。 + + Returns: + Optional[int]: 若存在测试表则返回固定版本,否则返回 ``None``。 + """ + if snapshot.has_table("legacy_records"): + return 2 + return None + + +class FakeMigrationProgressReporter(BaseMigrationProgressReporter): + """测试用迁移进度上报器。""" + + def __init__(self) -> None: + """初始化测试用进度上报器。""" + self.events: List[Tuple[str, Optional[int], Optional[str], Optional[str]]] = [] + + def open(self) -> None: + """记录打开事件。""" + self.events.append(("open", None, None, None)) + + def close(self) -> None: + """记录关闭事件。""" + self.events.append(("close", None, None, None)) + + def start( + self, + total: int, + description: str = "总迁移进度", + unit_name: str = "表", + ) -> None: + """记录启动事件。 + + Args: + total: 任务总数。 + description: 任务描述。 + unit_name: 进度单位名称。 + """ + self.events.append(("start", total, description, unit_name)) + + def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None: + """记录推进事件。 + + Args: + advance: 推进步数。 + item_name: 当前完成的项目名称。 + """ + self.events.append(("advance", advance, item_name, None)) + + +def _create_sqlite_engine(database_file: Path) -> Engine: + """创建测试用 SQLite 引擎。 + + Args: + database_file: 测试数据库文件路径。 + + Returns: + Engine: SQLite 引擎实例。 + """ + return create_engine( + f"sqlite:///{database_file}", + echo=False, + connect_args={"check_same_thread": False}, + ) + + +def _create_current_schema(connection: Connection) -> None: + """创建当前最新版本的数据库结构。 + + Args: + connection: 当前数据库连接。 + """ + import src.common.database.database_model # noqa: F401 + + SQLModel.metadata.create_all(connection) + + +def _create_legacy_v1_schema_with_sample_data(connection: Connection) -> None: + """创建带示例数据的旧版 ``0.x`` 数据库结构。 + + Args: + connection: 当前数据库连接。 + """ + connection.execute( + text( + """ + CREATE TABLE chat_streams ( + id INTEGER 
PRIMARY KEY, + stream_id TEXT NOT NULL, + create_time REAL NOT NULL, + last_active_time REAL NOT NULL, + platform TEXT NOT NULL, + user_id TEXT, + group_id TEXT, + group_name TEXT + ) + """ + ) + ) + connection.execute( + text( + """ + CREATE TABLE messages ( + id INTEGER PRIMARY KEY, + message_id TEXT NOT NULL, + time REAL NOT NULL, + chat_id TEXT NOT NULL, + chat_info_platform TEXT, + user_id TEXT, + user_nickname TEXT, + chat_info_group_id TEXT, + chat_info_group_name TEXT, + is_mentioned INTEGER, + is_at INTEGER, + processed_plain_text TEXT, + display_message TEXT, + is_emoji INTEGER, + is_picid INTEGER, + is_command INTEGER, + is_notify INTEGER, + additional_config TEXT, + priority_mode TEXT + ) + """ + ) + ) + connection.execute( + text( + """ + CREATE TABLE action_records ( + id INTEGER PRIMARY KEY, + action_id TEXT NOT NULL, + time REAL NOT NULL, + action_reasoning TEXT, + action_name TEXT NOT NULL, + action_data TEXT, + action_prompt_display TEXT, + chat_id TEXT + ) + """ + ) + ) + connection.execute( + text( + """ + CREATE TABLE expression ( + id INTEGER PRIMARY KEY, + situation TEXT NOT NULL, + style TEXT NOT NULL, + content_list TEXT, + count INTEGER, + last_active_time REAL NOT NULL, + chat_id TEXT, + create_date REAL, + checked INTEGER, + rejected INTEGER, + modified_by TEXT + ) + """ + ) + ) + connection.execute( + text( + """ + CREATE TABLE jargon ( + id INTEGER PRIMARY KEY, + content TEXT NOT NULL, + raw_content TEXT, + meaning TEXT, + chat_id TEXT, + is_global INTEGER, + count INTEGER, + is_jargon INTEGER, + last_inference_count INTEGER, + is_complete INTEGER, + inference_with_context TEXT, + inference_content_only TEXT + ) + """ + ) + ) + + connection.execute( + text( + """ + INSERT INTO chat_streams ( + id, + stream_id, + create_time, + last_active_time, + platform, + user_id, + group_id, + group_name + ) VALUES ( + 1, + 'session-1', + 1710000000.0, + 1710000300.0, + 'qq', + 'user-1', + 'group-1', + '测试群' + ) + """ + ) + ) + connection.execute( 
+ text( + """ + INSERT INTO messages ( + id, + message_id, + time, + chat_id, + chat_info_platform, + user_id, + user_nickname, + chat_info_group_id, + chat_info_group_name, + is_mentioned, + is_at, + processed_plain_text, + display_message, + is_emoji, + is_picid, + is_command, + is_notify, + additional_config, + priority_mode + ) VALUES ( + 1, + 'msg-1', + 1710000010.0, + 'session-1', + 'qq', + 'user-1', + '测试用户', + 'group-1', + '测试群', + 1, + 0, + '你好', + '你好呀', + 0, + 1, + 0, + 1, + '{"source":"legacy"}', + 'high' + ) + """ + ) + ) + connection.execute( + text( + """ + INSERT INTO action_records ( + id, + action_id, + time, + action_reasoning, + action_name, + action_data, + action_prompt_display, + chat_id + ) VALUES ( + 1, + 'action-1', + 1710000020.0, + '需要调用工具', + 'search', + '{"query":"MaiBot"}', + '执行搜索', + 'session-1' + ) + """ + ) + ) + connection.execute( + text( + """ + INSERT INTO expression ( + id, + situation, + style, + content_list, + count, + last_active_time, + chat_id, + create_date, + checked, + rejected, + modified_by + ) VALUES ( + 1, + '打招呼', + '可爱', + '["你好呀","早上好"]', + 3, + 1710000030.0, + 'session-1', + 1710000040.0, + 1, + 0, + 'ai' + ) + """ + ) + ) + connection.execute( + text( + """ + INSERT INTO jargon ( + id, + content, + raw_content, + meaning, + chat_id, + is_global, + count, + is_jargon, + last_inference_count, + is_complete, + inference_with_context, + inference_content_only + ) VALUES ( + 1, + '上分', + '["上分"]', + '提高排名', + 'session-1', + 0, + 5, + 1, + 2, + 1, + '{"guess":"context"}', + '{"guess":"content"}' + ) + """ + ) + ) + + +def test_user_version_store_can_read_and_write_versions(tmp_path: Path) -> None: + """应支持读取与写入 SQLite ``user_version``。""" + engine = _create_sqlite_engine(tmp_path / "version_store.db") + version_store = SQLiteUserVersionStore() + + with engine.begin() as connection: + assert version_store.read_version(connection) == 0 + version_store.write_version(connection, 7) + + with engine.connect() as 
connection: + assert version_store.read_version(connection) == 7 + + +def test_schema_inspector_can_extract_tables_and_columns(tmp_path: Path) -> None: + """应能提取 SQLite 数据库的表与列结构。""" + engine = _create_sqlite_engine(tmp_path / "schema_inspector.db") + inspector = SQLiteSchemaInspector() + + with engine.begin() as connection: + connection.execute( + text( + """ + CREATE TABLE legacy_records ( + id INTEGER PRIMARY KEY, + payload TEXT NOT NULL, + created_at TEXT + ) + """ + ) + ) + + with engine.connect() as connection: + snapshot = inspector.inspect(connection) + + assert snapshot.has_table("legacy_records") + assert snapshot.has_column("legacy_records", "payload") + assert not snapshot.has_column("legacy_records", "missing_column") + table_schema = snapshot.get_table("legacy_records") + + assert table_schema is not None + assert table_schema.column_names() == ["created_at", "id", "payload"] + + +def test_resolver_can_identify_empty_database(tmp_path: Path) -> None: + """空数据库应被解析为版本 ``0``。""" + engine = _create_sqlite_engine(tmp_path / "empty_resolver.db") + resolver = SchemaVersionResolver() + + with engine.connect() as connection: + resolved_version = resolver.resolve(connection) + + assert resolved_version.version == 0 + assert resolved_version.source == SchemaVersionSource.EMPTY_DATABASE + assert resolved_version.snapshot is not None + assert resolved_version.snapshot.is_empty() + + +def test_resolver_can_use_detector_for_unversioned_legacy_database(tmp_path: Path) -> None: + """未写入 ``user_version`` 的历史库应支持通过探测器识别版本。""" + engine = _create_sqlite_engine(tmp_path / "legacy_resolver.db") + resolver = SchemaVersionResolver(detectors=[FixedVersionDetector()]) + + with engine.begin() as connection: + connection.execute(text("CREATE TABLE legacy_records (id INTEGER PRIMARY KEY, payload TEXT NOT NULL)")) + + with engine.connect() as connection: + resolved_version = resolver.resolve(connection) + + assert resolved_version.version == 2 + assert resolved_version.source == 
SchemaVersionSource.DETECTOR + assert resolved_version.detector_name == "fixed_version_detector" + + +def test_registry_and_manager_can_execute_registered_steps(tmp_path: Path) -> None: + """迁移编排器应能按顺序执行已注册步骤并更新版本号。""" + engine = _create_sqlite_engine(tmp_path / "manager.db") + executed_steps: List[str] = [] + + def migrate_0_to_1(context: MigrationExecutionContext) -> None: + """测试迁移步骤 0 -> 1。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + executed_steps.append(f"{context.current_version}->{context.target_version}:step_0_to_1") + context.connection.execute(text("CREATE TABLE sample_records (id INTEGER PRIMARY KEY, name TEXT NOT NULL)")) + + def migrate_1_to_2(context: MigrationExecutionContext) -> None: + """测试迁移步骤 1 -> 2。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + executed_steps.append(f"{context.current_version}->{context.target_version}:step_1_to_2") + context.connection.execute(text("ALTER TABLE sample_records ADD COLUMN email TEXT")) + + registry = MigrationRegistry( + steps=[ + MigrationStep( + version_from=0, + version_to=1, + name="create_sample_records", + description="创建示例表。", + handler=migrate_0_to_1, + ), + MigrationStep( + version_from=1, + version_to=2, + name="add_sample_email", + description="为示例表增加邮箱字段。", + handler=migrate_1_to_2, + ), + ] + ) + manager = DatabaseMigrationManager(engine=engine, registry=registry) + + migration_plan = manager.migrate() + + assert migration_plan.step_count() == 2 + assert executed_steps == ["0->2:step_0_to_1", "1->2:step_1_to_2"] + + with engine.connect() as connection: + version_store = SQLiteUserVersionStore() + snapshot = SQLiteSchemaInspector().inspect(connection) + recorded_version = version_store.read_version(connection) + + assert recorded_version == 2 + assert snapshot.has_table("sample_records") + assert snapshot.has_column("sample_records", "email") + + +def test_manager_can_report_step_progress(tmp_path: Path) -> None: + """迁移编排器应支持通过上下文上报步骤进度。""" + engine = _create_sqlite_engine(tmp_path / 
"manager_progress.db") + reporter_instances: List[FakeMigrationProgressReporter] = [] + + def _build_reporter() -> BaseMigrationProgressReporter: + """构建测试用进度上报器。 + + Returns: + BaseMigrationProgressReporter: 测试用进度上报器实例。 + """ + reporter = FakeMigrationProgressReporter() + reporter_instances.append(reporter) + return reporter + + def migrate_1_to_2(context: MigrationExecutionContext) -> None: + """测试迁移步骤 ``1 -> 2`` 的进度上报。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + context.start_progress(total=3, description="总迁移进度", unit_name="表") + context.advance_progress(item_name="chat_sessions") + context.advance_progress(item_name="mai_messages") + context.advance_progress(item_name="tool_records") + context.connection.execute(text("CREATE TABLE progress_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)")) + + with engine.begin() as connection: + SQLiteUserVersionStore().write_version(connection, 1) + + registry = MigrationRegistry( + steps=[ + MigrationStep( + version_from=1, + version_to=2, + name="progress_step", + description="测试进度上报。", + handler=migrate_1_to_2, + ) + ] + ) + manager = DatabaseMigrationManager( + engine=engine, + registry=registry, + progress_reporter_factory=_build_reporter, + ) + + migration_plan = manager.migrate() + + assert migration_plan.step_count() == 1 + assert len(reporter_instances) == 1 + assert reporter_instances[0].events == [ + ("open", None, None, None), + ("start", 3, "总迁移进度", "表"), + ("advance", 1, "chat_sessions", None), + ("advance", 1, "mai_messages", None), + ("advance", 1, "tool_records", None), + ("close", None, None, None), + ] + + +def test_default_resolver_can_identify_unversioned_latest_database(tmp_path: Path) -> None: + """默认解析器应能识别未写入版本号的最新结构数据库。""" + engine = _create_sqlite_engine(tmp_path / "latest_resolver.db") + resolver = build_default_schema_version_resolver() + + with engine.begin() as connection: + _create_current_schema(connection) + + with engine.connect() as connection: + resolved_version = 
resolver.resolve(connection) + + assert resolved_version.version == LATEST_SCHEMA_VERSION + assert resolved_version.source == SchemaVersionSource.DETECTOR + assert resolved_version.detector_name == "latest_schema_detector" + + +def test_default_resolver_can_identify_legacy_v1_database(tmp_path: Path) -> None: + """默认解析器应能识别未写版本号的旧版 ``0.x`` 数据库。""" + engine = _create_sqlite_engine(tmp_path / "legacy_v1_resolver.db") + resolver = build_default_schema_version_resolver() + + with engine.begin() as connection: + _create_legacy_v1_schema_with_sample_data(connection) + + with engine.connect() as connection: + resolved_version = resolver.resolve(connection) + + assert resolved_version.version == LEGACY_V1_SCHEMA_VERSION + assert resolved_version.source == SchemaVersionSource.DETECTOR + assert resolved_version.detector_name == "legacy_v1_schema_detector" + + +def test_bootstrapper_can_finalize_unversioned_latest_database(tmp_path: Path) -> None: + """已是最新结构但未写版本号的数据库应直接补写 ``user_version``。""" + engine = _create_sqlite_engine(tmp_path / "latest_finalize.db") + bootstrapper = create_database_migration_bootstrapper(engine) + + with engine.begin() as connection: + _create_current_schema(connection) + + migration_state = bootstrapper.prepare_database() + bootstrapper.finalize_database(migration_state) + + assert not migration_state.requires_migration() + assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION + assert migration_state.resolved_version.source == SchemaVersionSource.DETECTOR + + with engine.connect() as connection: + recorded_version = SQLiteUserVersionStore().read_version(connection) + + assert recorded_version == LATEST_SCHEMA_VERSION + + +def test_bootstrapper_can_finalize_empty_database_to_latest_version(tmp_path: Path) -> None: + """空库在建表完成后应回写最新 ``user_version``。""" + engine = _create_sqlite_engine(tmp_path / "bootstrap_empty.db") + bootstrapper = create_database_migration_bootstrapper(engine) + + migration_state = 
bootstrapper.prepare_database() + + assert not migration_state.requires_migration() + assert migration_state.resolved_version.version == EMPTY_SCHEMA_VERSION + assert migration_state.target_version == LATEST_SCHEMA_VERSION + + with engine.begin() as connection: + _create_current_schema(connection) + + bootstrapper.finalize_database(migration_state) + + with engine.connect() as connection: + recorded_version = SQLiteUserVersionStore().read_version(connection) + + assert recorded_version == LATEST_SCHEMA_VERSION + + +def test_bootstrapper_runs_registered_steps_for_versioned_database(tmp_path: Path) -> None: + """启动桥接器应在已登记旧版本数据库上执行注册迁移步骤。""" + engine = _create_sqlite_engine(tmp_path / "bootstrap_registered.db") + execution_marks: List[str] = [] + + def migrate_1_to_2(context: MigrationExecutionContext) -> None: + """测试桥接器迁移步骤 ``1 -> 2``。 + + Args: + context: 当前迁移步骤执行上下文。 + """ + execution_marks.append(f"step={context.step_name},index={context.step_index}") + context.connection.execute(text("ALTER TABLE bootstrap_records ADD COLUMN email TEXT")) + + with engine.begin() as connection: + connection.execute( + text("CREATE TABLE bootstrap_records (id INTEGER PRIMARY KEY, value TEXT NOT NULL)") + ) + SQLiteUserVersionStore().write_version(connection, 1) + + registry = MigrationRegistry( + steps=[ + MigrationStep( + version_from=1, + version_to=2, + name="bootstrap_add_email", + description="为桥接器测试表增加邮箱字段。", + handler=migrate_1_to_2, + ) + ] + ) + bootstrapper = DatabaseMigrationBootstrapper( + manager=DatabaseMigrationManager(engine=engine, registry=registry), + latest_schema_version=2, + ) + + migration_state = bootstrapper.prepare_database() + + assert migration_state.resolved_version.version == 2 + assert migration_state.target_version == 2 + assert execution_marks == ["step=bootstrap_add_email,index=1"] + + with engine.connect() as connection: + snapshot = SQLiteSchemaInspector().inspect(connection) + recorded_version = 
SQLiteUserVersionStore().read_version(connection) + + assert recorded_version == 2 + assert snapshot.has_table("bootstrap_records") + assert snapshot.has_column("bootstrap_records", "email") + + +def test_default_bootstrapper_can_migrate_legacy_v1_database(tmp_path: Path) -> None: + """默认桥接器应能把旧版 ``0.x`` 数据库整体迁移到最新结构。""" + engine = _create_sqlite_engine(tmp_path / "legacy_v1_to_v2.db") + bootstrapper = create_database_migration_bootstrapper(engine) + + with engine.begin() as connection: + _create_legacy_v1_schema_with_sample_data(connection) + + migration_state = bootstrapper.prepare_database() + bootstrapper.finalize_database(migration_state) + + assert not migration_state.requires_migration() + assert migration_state.resolved_version.version == LATEST_SCHEMA_VERSION + assert migration_state.resolved_version.source == SchemaVersionSource.PRAGMA + + with engine.connect() as connection: + recorded_version = SQLiteUserVersionStore().read_version(connection) + snapshot = SQLiteSchemaInspector().inspect(connection) + message_row = connection.execute( + text( + """ + SELECT session_id, processed_plain_text, additional_config, raw_content + FROM mai_messages + WHERE message_id = 'msg-1' + """ + ) + ).mappings().one() + action_row = connection.execute( + text( + """ + SELECT session_id, action_name, action_display_prompt + FROM action_records + WHERE action_id = 'action-1' + """ + ) + ).mappings().one() + tool_row = connection.execute( + text( + """ + SELECT session_id, tool_name, tool_display_prompt + FROM tool_records + WHERE tool_id = 'action-1' + """ + ) + ).mappings().one() + expression_row = connection.execute( + text( + """ + SELECT session_id, content_list, modified_by + FROM expressions + WHERE id = 1 + """ + ) + ).mappings().one() + jargon_row = connection.execute( + text( + """ + SELECT session_id_dict, raw_content, inference_with_content_only + FROM jargons + WHERE id = 1 + """ + ) + ).mappings().one() + + assert recorded_version == LATEST_SCHEMA_VERSION + 
assert snapshot.has_table("__legacy_v1_messages") + assert snapshot.has_table("chat_sessions") + assert snapshot.has_table("mai_messages") + assert snapshot.has_table("tool_records") + + unpacked_raw_content = msgpack.unpackb(message_row["raw_content"], raw=False) + additional_config = json.loads(message_row["additional_config"]) + expression_content_list = json.loads(expression_row["content_list"]) + jargon_session_id_dict = json.loads(jargon_row["session_id_dict"]) + jargon_raw_content = json.loads(jargon_row["raw_content"]) + + assert message_row["session_id"] == "session-1" + assert message_row["processed_plain_text"] == "你好" + assert unpacked_raw_content == [{"type": "text", "data": "你好呀"}] + assert additional_config == {"priority_mode": "high", "source": "legacy"} + assert action_row["session_id"] == "session-1" + assert action_row["action_name"] == "search" + assert action_row["action_display_prompt"] == "执行搜索" + assert tool_row["session_id"] == "session-1" + assert tool_row["tool_name"] == "search" + assert tool_row["tool_display_prompt"] == "执行搜索" + assert expression_row["session_id"] == "session-1" + assert expression_row["modified_by"] == "AI" + assert expression_content_list == ["你好呀", "早上好"] + assert jargon_session_id_dict == {"session-1": 5} + assert jargon_raw_content == ["上分"] + assert jargon_row["inference_with_content_only"] == '{"guess":"content"}' + + +def test_legacy_v1_migration_reports_table_progress(tmp_path: Path) -> None: + """旧版迁移步骤应按目标表数量推进总进度。""" + engine = _create_sqlite_engine(tmp_path / "legacy_progress.db") + reporter_instances: List[FakeMigrationProgressReporter] = [] + + def _build_reporter() -> BaseMigrationProgressReporter: + """构建测试用进度上报器。 + + Returns: + BaseMigrationProgressReporter: 测试用进度上报器实例。 + """ + reporter = FakeMigrationProgressReporter() + reporter_instances.append(reporter) + return reporter + + with engine.begin() as connection: + _create_legacy_v1_schema_with_sample_data(connection) + + manager = 
DatabaseMigrationManager( + engine=engine, + registry=build_default_migration_registry(), + resolver=build_default_schema_version_resolver(), + progress_reporter_factory=_build_reporter, + ) + + migration_plan = manager.migrate(target_version=LATEST_SCHEMA_VERSION) + + assert migration_plan.step_count() == 1 + assert len(reporter_instances) == 1 + reporter_events = reporter_instances[0].events + + assert reporter_events[0] == ("open", None, None, None) + assert reporter_events[1] == ("start", 12, "总迁移进度", "表") + assert reporter_events[-1] == ("close", None, None, None) + assert reporter_events.count(("advance", 1, "chat_sessions", None)) == 1 + assert reporter_events.count(("advance", 1, "thinking_questions", None)) == 1 + assert len([event for event in reporter_events if event[0] == "advance"]) == 12 + + +def test_initialize_database_calls_bootstrapper_before_create_all( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +) -> None: + """数据库初始化入口应先准备迁移,再建表、补迁移并收尾。""" + call_order: List[str] = [] + + def _fake_prepare_database() -> DatabaseMigrationState: + """返回测试用迁移状态。 + + Returns: + DatabaseMigrationState: 不包含迁移步骤的测试状态。 + """ + call_order.append("prepare_database") + return DatabaseMigrationState( + resolved_version=ResolvedSchemaVersion(version=0, source=SchemaVersionSource.EMPTY_DATABASE), + target_version=LATEST_SCHEMA_VERSION, + plan=MigrationPlan( + current_version=EMPTY_SCHEMA_VERSION, + target_version=LATEST_SCHEMA_VERSION, + steps=[], + ), + ) + + def _fake_create_all(bind) -> None: + """记录建表调用。 + + Args: + bind: 传入的数据库绑定对象。 + """ + del bind + call_order.append("create_all") + + def _fake_migrate_action_records() -> None: + """记录轻量补迁移调用。""" + call_order.append("migrate_action_records") + + def _fake_finalize_database(migration_state: DatabaseMigrationState) -> None: + """记录迁移收尾调用。 + + Args: + migration_state: 当前数据库迁移状态。 + """ + del migration_state + call_order.append("finalize_database") + + monkeypatch.setattr(database_module, "_db_initialized", 
False) + monkeypatch.setattr(database_module, "_DB_DIR", tmp_path / "data") + monkeypatch.setattr(database_module._migration_bootstrapper, "prepare_database", _fake_prepare_database) + monkeypatch.setattr(database_module._migration_bootstrapper, "finalize_database", _fake_finalize_database) + monkeypatch.setattr(database_module.SQLModel.metadata, "create_all", _fake_create_all) + monkeypatch.setattr(database_module, "_migrate_action_records_to_tool_records", _fake_migrate_action_records) + + database_module.initialize_database() + + assert call_order == [ + "prepare_database", + "create_all", + "migrate_action_records", + "finalize_database", + ] diff --git a/src/common/database/database.py b/src/common/database/database.py index 293b47d2..2b22475a 100644 --- a/src/common/database/database.py +++ b/src/common/database/database.py @@ -1,18 +1,23 @@ -from rich.traceback import install from contextlib import contextmanager from pathlib import Path -from typing import Generator, TYPE_CHECKING +from typing import ContextManager, Generator, TYPE_CHECKING +from rich.traceback import install from sqlalchemy import event, text from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from sqlmodel import SQLModel, Session, create_engine +from src.common.database.migrations import create_database_migration_bootstrapper +from src.common.logger import get_logger + if TYPE_CHECKING: from sqlite3 import Connection as SQLite3Connection install(extra_lines=3) +logger = get_logger("database") + # 定义数据库文件路径 ROOT_PATH = Path(__file__).parent.parent.parent.parent.absolute().resolve() @@ -53,6 +58,7 @@ SessionLocal = sessionmaker( bind=engine, class_=Session, ) +_migration_bootstrapper = create_database_migration_bootstrapper(engine) _db_initialized = False @@ -93,14 +99,29 @@ def _migrate_action_records_to_tool_records() -> None: def initialize_database() -> None: + """初始化数据库连接、结构与启动期迁移。 + + 当前初始化流程遵循以下顺序: + 1. 确保数据库目录存在; + 2. 加载 SQLModel 模型定义; + 3. 
执行已注册的启动期迁移; + 4. 兜底执行 ``create_all`` 确保当前模型定义已建表; + 5. 执行项目现有的轻量数据补迁移逻辑,并在需要时补写最终 ``user_version``。 + """ global _db_initialized if _db_initialized: return _DB_DIR.mkdir(parents=True, exist_ok=True) import src.common.database.database_model # noqa: F401 + migration_state = _migration_bootstrapper.prepare_database() + logger.info( + "数据库迁移准备完成," + f" 当前版本={migration_state.resolved_version.version},目标版本={migration_state.target_version}" + ) SQLModel.metadata.create_all(engine) _migrate_action_records_to_tool_records() + _migration_bootstrapper.finalize_database(migration_state) _db_initialized = True @@ -150,8 +171,12 @@ def get_db_session(auto_commit: bool = True) -> Generator[Session, None, None]: session.close() -def get_db_session_manual(): - """获取数据库会话的上下文管理器 (手动提交模式)。""" +def get_db_session_manual() -> ContextManager[Session]: + """获取数据库会话的上下文管理器 (手动提交模式)。 + + Returns: + ContextManager[Session]: 手动提交模式的数据库会话上下文管理器。 + """ return get_db_session(auto_commit=False) diff --git a/src/common/database/migrations/__init__.py b/src/common/database/migrations/__init__.py new file mode 100644 index 00000000..e9a69bd1 --- /dev/null +++ b/src/common/database/migrations/__init__.py @@ -0,0 +1,79 @@ +"""数据库迁移基础设施导出模块。""" + +from .bootstrap import DatabaseMigrationBootstrapper, create_database_migration_bootstrapper +from .builtin import ( + EMPTY_SCHEMA_VERSION, + LATEST_SCHEMA_VERSION, + LEGACY_V1_SCHEMA_VERSION, + build_default_migration_registry, + build_default_schema_version_resolver, +) +from .exceptions import ( + DatabaseMigrationConfigurationError, + DatabaseMigrationError, + DatabaseMigrationExecutionError, + DatabaseMigrationPlanningError, + DatabaseMigrationVersionError, + MissingMigrationStepError, + UnrecognizedDatabaseSchemaError, + UnsupportedMigrationDirectionError, +) +from .manager import DatabaseMigrationManager +from .models import ( + ColumnSchema, + DatabaseMigrationState, + DatabaseSchemaSnapshot, + MigrationExecutionContext, + MigrationPlan, + MigrationStep, + 
ResolvedSchemaVersion, + SchemaVersionSource, + TableSchema, +) +from .planner import MigrationPlanner +from .progress import ( + BaseMigrationProgressReporter, + RichMigrationProgressReporter, + create_rich_migration_progress_reporter, +) +from .registry import MigrationRegistry +from .resolver import BaseSchemaVersionDetector, SchemaVersionResolver +from .schema import SQLiteSchemaInspector +from .version_store import SQLiteUserVersionStore + +__all__ = [ + "BaseSchemaVersionDetector", + "BaseMigrationProgressReporter", + "build_default_migration_registry", + "build_default_schema_version_resolver", + "ColumnSchema", + "create_database_migration_bootstrapper", + "create_rich_migration_progress_reporter", + "DatabaseMigrationConfigurationError", + "DatabaseMigrationError", + "DatabaseMigrationBootstrapper", + "DatabaseMigrationExecutionError", + "DatabaseMigrationManager", + "DatabaseMigrationPlanningError", + "DatabaseMigrationState", + "DatabaseMigrationVersionError", + "DatabaseSchemaSnapshot", + "EMPTY_SCHEMA_VERSION", + "LATEST_SCHEMA_VERSION", + "LEGACY_V1_SCHEMA_VERSION", + "MigrationExecutionContext", + "MigrationPlan", + "MigrationPlanner", + "MigrationRegistry", + "MigrationStep", + "MissingMigrationStepError", + "ResolvedSchemaVersion", + "RichMigrationProgressReporter", + "SchemaVersionResolver", + "SchemaVersionSource", + "SQLiteSchemaInspector", + "SQLiteUserVersionStore", + "TableSchema", + "UnrecognizedDatabaseSchemaError", + "UnsupportedMigrationDirectionError", +] diff --git a/src/common/database/migrations/bootstrap.py b/src/common/database/migrations/bootstrap.py new file mode 100644 index 00000000..a7a0a779 --- /dev/null +++ b/src/common/database/migrations/bootstrap.py @@ -0,0 +1,171 @@ +"""数据库迁移启动桥接层。""" + +from typing import Optional + +from sqlalchemy.engine import Engine + +from src.common.logger import get_logger + +from .builtin import ( + LATEST_SCHEMA_VERSION, + build_default_migration_registry, + 
build_default_schema_version_resolver, +) +from .exceptions import DatabaseMigrationExecutionError +from .manager import DatabaseMigrationManager +from .models import DatabaseMigrationState, MigrationPlan, ResolvedSchemaVersion, SchemaVersionSource +from .registry import MigrationRegistry +from .resolver import SchemaVersionResolver +from .version_store import SQLiteUserVersionStore + +logger = get_logger("database_migration") + + +class DatabaseMigrationBootstrapper: + """数据库迁移启动桥接器。 + + 该桥接器负责把数据库迁移基础设施接入现有启动流程,同时保持如下约束: + 1. 若数据库为空,则直接交给当前模型定义建出最新结构; + 2. 若数据库版本高于当前代码支持的最新版本,则立即终止启动; + 3. 若存在待执行迁移步骤,则在正常建表流程之前先执行迁移; + 4. 若数据库已是最新结构但尚未写入 ``user_version``,则在建表后补写版本号。 + """ + + def __init__( + self, + manager: DatabaseMigrationManager, + latest_schema_version: int = LATEST_SCHEMA_VERSION, + ) -> None: + """初始化数据库迁移启动桥接器。 + + Args: + manager: 数据库迁移编排器。 + latest_schema_version: 当前代码支持的最新 schema 版本号。 + """ + self.manager = manager + self.latest_schema_version = latest_schema_version + + def prepare_database(self) -> DatabaseMigrationState: + """为数据库初始化阶段准备迁移状态。 + + Returns: + DatabaseMigrationState: 迁移准备完成后的数据库状态。 + + Raises: + DatabaseMigrationExecutionError: 当数据库版本高于当前代码支持版本时抛出。 + """ + with self.manager.engine.connect() as connection: + resolved_version = self.manager.resolver.resolve(connection) + + if resolved_version.version > self.latest_schema_version: + raise DatabaseMigrationExecutionError( + "当前数据库版本高于代码内注册的最新迁移版本,已拒绝继续启动。" + f" 数据库版本={resolved_version.version},代码支持版本={self.latest_schema_version}" + ) + + if resolved_version.source == SchemaVersionSource.EMPTY_DATABASE: + logger.info( + "检测到空数据库,将直接根据当前模型创建最新结构。" + f" 目标版本={self.latest_schema_version}" + ) + return self._build_noop_state( + current_version=resolved_version.version, + target_version=self.latest_schema_version, + resolved_state=resolved_version, + ) + + migration_state = self.manager.describe_state(target_version=self.latest_schema_version) + if not migration_state.requires_migration(): + 
def create_database_migration_bootstrapper(
    engine: Engine,
    registry: Optional[MigrationRegistry] = None,
    resolver: Optional[SchemaVersionResolver] = None,
    version_store: Optional[SQLiteUserVersionStore] = None,
    latest_schema_version: int = LATEST_SCHEMA_VERSION,
) -> DatabaseMigrationBootstrapper:
    """Create a fully wired database migration bootstrapper.

    Optional collaborators are replaced by defaults only when they are
    ``None``. An identity check is used instead of truthiness so that an
    explicitly supplied but falsy object (for example an empty registry that
    implements ``__len__``) is never silently swapped for the default.

    Args:
        engine: Target database engine.
        registry: Migration step registry; the default registry is used when
            omitted.
        resolver: Schema version resolver; the default resolver is used when
            omitted.
        version_store: Version store; a ``SQLiteUserVersionStore`` is used when
            omitted.
        latest_schema_version: Newest schema version supported by the code.

    Returns:
        DatabaseMigrationBootstrapper: Configured bootstrapper.
    """
    migration_registry = build_default_migration_registry() if registry is None else registry
    migration_resolver = build_default_schema_version_resolver() if resolver is None else resolver
    migration_version_store = SQLiteUserVersionStore() if version_store is None else version_store
    migration_manager = DatabaseMigrationManager(
        engine=engine,
        registry=migration_registry,
        resolver=migration_resolver,
        version_store=migration_version_store,
    )
    return DatabaseMigrationBootstrapper(
        manager=migration_manager,
        latest_schema_version=latest_schema_version,
    )
EMPTY_SCHEMA_VERSION = 0
LEGACY_V1_SCHEMA_VERSION = 1
LATEST_SCHEMA_VERSION = 2

# Tables that only exist in the legacy 0.x layout.
_LEGACY_V1_EXCLUSIVE_TABLES = (
    "chat_streams",
    "emoji",
    "emoji_description_cache",
    "expression",
    "group_info",
    "image_descriptions",
    "jargon",
    "messages",
    "thinking_back",
)


class LatestSchemaVersionDetector(BaseSchemaVersionDetector):
    """Detector that recognizes the current latest schema structure."""

    @property
    def name(self) -> str:
        """Return the detector name.

        Returns:
            str: Name of this detector.
        """
        return "latest_schema_detector"

    def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
        """Check whether the database already has the latest structure.

        Args:
            snapshot: Snapshot of the current database structure.

        Returns:
            Optional[int]: The latest version number when the structure is
            recognized, otherwise ``None``.
        """
        # Any legacy-only table rules out the latest layout immediately.
        for legacy_table in _LEGACY_V1_EXCLUSIVE_TABLES:
            if snapshot.has_table(legacy_table):
                return None

        required_tables = (
            "mai_messages",
            "chat_sessions",
            "expressions",
            "jargons",
            "thinking_questions",
            "tool_records",
        )
        required_columns = (
            ("images", "image_hash"),
            ("images", "full_path"),
            ("images", "image_type"),
            ("action_records", "session_id"),
            ("chat_history", "session_id"),
            ("person_info", "user_nickname"),
        )

        tables_present = all(snapshot.has_table(table) for table in required_tables)
        if not tables_present:
            return None
        columns_present = all(snapshot.has_column(table, column) for table, column in required_columns)
        if not columns_present:
            return None
        return LATEST_SCHEMA_VERSION
class LegacyV1SchemaDetector(BaseSchemaVersionDetector):
    """Detector that recognizes the legacy ``0.x`` schema structure."""

    @property
    def name(self) -> str:
        """Return the detector name.

        Returns:
            str: Name of this detector.
        """
        return "legacy_v1_schema_detector"

    def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
        """Check whether the database uses the legacy ``0.x`` structure.

        Args:
            snapshot: Snapshot of the current database structure.

        Returns:
            Optional[int]: ``1`` when the legacy structure is recognized,
            otherwise ``None``.
        """
        for legacy_only_table in _LEGACY_V1_EXCLUSIVE_TABLES:
            if snapshot.has_table(legacy_only_table):
                return LEGACY_V1_SCHEMA_VERSION

        # Tables that exist in both layouts; the listed columns only exist in
        # the legacy variant of each table.
        legacy_shared_markers = (
            ("action_records", ("chat_id", "time")),
            ("chat_history", ("chat_id", "original_text")),
            ("images", ("emoji_hash", "path", "type")),
            ("llm_usage", ("model_api_provider", "status")),
            ("online_time", ("duration",)),
            ("person_info", ("nickname", "group_nick_name")),
        )
        matches_legacy_layout = any(
            snapshot.has_table(shared_table)
            and all(snapshot.has_column(shared_table, column) for column in legacy_columns)
            for shared_table, legacy_columns in legacy_shared_markers
        )
        return LEGACY_V1_SCHEMA_VERSION if matches_legacy_layout else None
def build_default_schema_version_detectors() -> List[BaseSchemaVersionDetector]:
    """Build the default chain of schema version detectors.

    Returns:
        List[BaseSchemaVersionDetector]: Detectors ordered by priority.
    """
    detector_chain: List[BaseSchemaVersionDetector] = []
    detector_chain.append(LatestSchemaVersionDetector())
    detector_chain.append(LegacyV1SchemaDetector())
    return detector_chain


def build_default_schema_version_resolver() -> SchemaVersionResolver:
    """Build the default schema version resolver.

    Returns:
        SchemaVersionResolver: Configured schema version resolver.
    """
    store = SQLiteUserVersionStore()
    inspector = SQLiteSchemaInspector()
    detectors = build_default_schema_version_detectors()
    return SchemaVersionResolver(
        version_store=store,
        schema_inspector=inspector,
        detectors=detectors,
    )


def build_default_migration_registry() -> MigrationRegistry:
    """Build the default migration step registry.

    Returns:
        MigrationRegistry: Registry containing the built-in migration steps.
    """
    legacy_to_latest_step = MigrationStep(
        version_from=LEGACY_V1_SCHEMA_VERSION,
        version_to=LATEST_SCHEMA_VERSION,
        name="legacy_v1_to_latest_v2",
        description="将旧版 0.x 数据库整体迁移到当前最新 schema。",
        handler=migrate_legacy_v1_to_v2,
    )
    return MigrationRegistry(steps=[legacy_to_latest_step])


class DatabaseMigrationError(Exception):
    """Base class for all database migration errors."""


class DatabaseMigrationConfigurationError(DatabaseMigrationError):
    """The database migration configuration is invalid."""


class DatabaseMigrationPlanningError(DatabaseMigrationError):
    """A database migration plan could not be generated."""


class DatabaseMigrationExecutionError(DatabaseMigrationError):
    """A database migration failed while executing."""


class DatabaseMigrationVersionError(DatabaseMigrationError):
    """Reading, writing or validating the database version failed."""


class MissingMigrationStepError(DatabaseMigrationPlanningError):
    """A migration step required for some version range is missing."""


class UnsupportedMigrationDirectionError(DatabaseMigrationPlanningError):
    """The requested migration direction is not supported."""


class UnrecognizedDatabaseSchemaError(DatabaseMigrationVersionError):
    """The structure of an unversioned database could not be recognized."""
# Prefix given to legacy tables once they are renamed out of the way.
_LEGACY_V1_BACKUP_PREFIX = "__legacy_v1_"
# Every table name that may exist in a legacy 0.x database.
_LEGACY_V1_TABLE_NAMES = (
    "action_records",
    "chat_history",
    "chat_streams",
    "emoji",
    "emoji_description_cache",
    "expression",
    "group_info",
    "image_descriptions",
    "images",
    "jargon",
    "llm_usage",
    "messages",
    "online_time",
    "person_info",
    "thinking_back",
)
# msgpack encoding of an empty list, used as placeholder message content.
_EMPTY_MESSAGE_SEQUENCE_BYTES = msgpack.packb([], use_bin_type=True)


@dataclass(frozen=True)
class LegacyTableData:
    """Snapshot of the data held by one legacy table."""

    # Name of the (renamed) table the rows were read from.
    source_table_name: str
    # Column names reported for the source table.
    columns: Set[str]
    # All rows of the source table as plain dictionaries.
    rows: List[Dict[str, Any]]


def migrate_legacy_v1_to_v2(context: MigrationExecutionContext) -> None:
    """Migrate a legacy ``0.x`` database to the latest schema.

    Legacy tables are renamed to backup names, the current models create the
    new tables, and each per-table job then copies data from the backups.

    Args:
        context: Execution context of the current migration step.
    """
    from sqlmodel import SQLModel

    # Imported for its side effect of registering the models on the metadata.
    import src.common.database.database_model  # noqa: F401

    legacy_snapshot = SQLiteSchemaInspector().inspect(context.connection)
    _rename_legacy_v1_tables(context.connection, legacy_snapshot)
    SQLModel.metadata.create_all(context.connection)

    jobs: List[Tuple[str, Callable[[Connection], int]]] = [
        ("chat_sessions", _migrate_chat_sessions),
        ("llm_usage", _migrate_model_usage),
        ("images", _migrate_images),
        ("mai_messages", _migrate_messages),
        ("action_records", _migrate_action_records),
        ("tool_records", _migrate_tool_records),
        ("online_time", _migrate_online_time),
        ("person_info", _migrate_person_info),
        ("expressions", _migrate_expressions),
        ("jargons", _migrate_jargons),
        ("chat_history", _migrate_chat_history),
        ("thinking_questions", _migrate_thinking_questions),
    ]
    per_table_counts: Dict[str, int] = {}
    context.start_progress(total=len(jobs), description="总迁移进度", unit_name="表")
    for target_table, run_job in jobs:
        per_table_counts[target_table] = run_job(context.connection)
        context.advance_progress(item_name=target_table)

    summary_parts = [f"{table_name}={count}" for table_name, count in per_table_counts.items()]
    summary_text = ", ".join(summary_parts)
    logger.info(f"旧版数据库迁移完成: {summary_text}")
("llm_usage", _migrate_model_usage), + ("images", _migrate_images), + ("mai_messages", _migrate_messages), + ("action_records", _migrate_action_records), + ("tool_records", _migrate_tool_records), + ("online_time", _migrate_online_time), + ("person_info", _migrate_person_info), + ("expressions", _migrate_expressions), + ("jargons", _migrate_jargons), + ("chat_history", _migrate_chat_history), + ("thinking_questions", _migrate_thinking_questions), + ] + migrated_counts: Dict[str, int] = {} + context.start_progress(total=len(table_migration_jobs), description="总迁移进度", unit_name="表") + for table_name, migration_handler in table_migration_jobs: + migrated_counts[table_name] = migration_handler(context.connection) + context.advance_progress(item_name=table_name) + + summary_text = ", ".join(f"{table_name}={count}" for table_name, count in migrated_counts.items()) + logger.info(f"旧版数据库迁移完成: {summary_text}") + + +def _legacy_backup_table_name(table_name: str) -> str: + """构建旧版表的备份表名。 + + Args: + table_name: 旧版原始表名。 + + Returns: + str: 带前缀的备份表名。 + """ + return f"{_LEGACY_V1_BACKUP_PREFIX}{table_name}" + + +def _quote_identifier(identifier: str) -> str: + """为 SQLite 标识符添加安全引号。 + + Args: + identifier: 待引用的标识符。 + + Returns: + str: 可安全拼接到 SQL 中的标识符。 + """ + escaped_identifier = identifier.replace('"', '""') + return f'"{escaped_identifier}"' + + +def _rename_legacy_v1_tables(connection: Connection, snapshot: DatabaseSchemaSnapshot) -> None: + """将旧版表统一改名为带备份前缀的表名。 + + Args: + connection: 当前数据库连接。 + snapshot: 当前数据库结构快照。 + + Raises: + DatabaseMigrationExecutionError: 当发现同名旧表与备份表同时存在时抛出。 + """ + for table_name in _LEGACY_V1_TABLE_NAMES: + if not snapshot.has_table(table_name): + continue + backup_table_name = _legacy_backup_table_name(table_name) + if snapshot.has_table(backup_table_name): + raise DatabaseMigrationExecutionError( + "检测到旧版表与迁移备份表同时存在,无法安全继续迁移。" + f" 冲突表={table_name},备份表={backup_table_name}" + ) + connection.execute( + text( + f"ALTER TABLE 
{_quote_identifier(table_name)} " + f"RENAME TO {_quote_identifier(backup_table_name)}" + ) + ) + + +def _load_legacy_table_data(connection: Connection, original_table_name: str) -> Optional[LegacyTableData]: + """加载单张旧版备份表的数据快照。 + + Args: + connection: 当前数据库连接。 + original_table_name: 旧版原始表名。 + + Returns: + Optional[LegacyTableData]: 若备份表存在则返回其数据快照,否则返回 ``None``。 + """ + backup_table_name = _legacy_backup_table_name(original_table_name) + schema_inspector = SQLiteSchemaInspector() + if not schema_inspector.table_exists(connection, backup_table_name): + return None + + table_schema = schema_inspector.get_table_schema(connection, backup_table_name) + rows = connection.execute(text(f"SELECT * FROM {_quote_identifier(backup_table_name)}")).mappings().all() + return LegacyTableData( + source_table_name=backup_table_name, + columns=set(table_schema.columns), + rows=[dict(row) for row in rows], + ) + + +def _normalize_optional_text(value: Any) -> Optional[str]: + """将任意值标准化为可空字符串。 + + Args: + value: 待标准化的原始值。 + + Returns: + Optional[str]: 标准化后的文本;若值为空则返回 ``None``。 + """ + if value is None: + return None + text_value = str(value).strip() + return text_value or None + + +def _normalize_required_text(value: Any, default: str = "") -> str: + """将任意值标准化为非空字符串。 + + Args: + value: 待标准化的原始值。 + default: 为空时使用的默认值。 + + Returns: + str: 标准化后的字符串。 + """ + normalized_value = _normalize_optional_text(value) + if normalized_value is None: + return default + return normalized_value + + +def _normalize_int(value: Any, default: int = 0) -> int: + """将任意值标准化为整数。 + + Args: + value: 待标准化的原始值。 + default: 转换失败时的默认值。 + + Returns: + int: 标准化后的整数。 + """ + if value is None or value == "": + return default + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _normalize_float(value: Any, default: float = 0.0) -> float: + """将任意值标准化为浮点数。 + + Args: + value: 待标准化的原始值。 + default: 转换失败时的默认值。 + + Returns: + float: 标准化后的浮点数。 + """ + if value is None or value == "": + return 
default + try: + return float(value) + except (TypeError, ValueError): + return default + + +def _normalize_optional_bool(value: Any) -> Optional[bool]: + """将任意值标准化为可空布尔值。 + + Args: + value: 待标准化的原始值。 + + Returns: + Optional[bool]: 标准化后的布尔值;若无法确定则返回 ``None``。 + """ + if value is None: + return None + if isinstance(value, bool): + return value + if isinstance(value, (int, float)): + return bool(int(value)) + + normalized_text = str(value).strip().lower() + if normalized_text in {"", "null", "none"}: + return None + if normalized_text in {"1", "true", "t", "yes", "y"}: + return True + if normalized_text in {"0", "false", "f", "no", "n"}: + return False + return None + + +def _normalize_bool(value: Any, default: bool = False) -> bool: + """将任意值标准化为布尔值。 + + Args: + value: 待标准化的原始值。 + default: 无法识别时的默认值。 + + Returns: + bool: 标准化后的布尔值。 + """ + parsed_value = _normalize_optional_bool(value) + return default if parsed_value is None else parsed_value + + +def _coerce_datetime(value: Any, fallback_now: bool = False) -> Optional[datetime]: + """将旧版时间字段标准化为 ``datetime``。 + + Args: + value: 待转换的原始值。 + fallback_now: 转换失败时是否回退到当前时间。 + + Returns: + Optional[datetime]: 转换后的时间对象。 + """ + if value is None or value == "": + return datetime.now() if fallback_now else None + if isinstance(value, datetime): + return value + if isinstance(value, (int, float)): + try: + return datetime.fromtimestamp(float(value)) + except (OSError, OverflowError, ValueError): + return datetime.now() if fallback_now else None + + normalized_text = str(value).strip() + if not normalized_text: + return datetime.now() if fallback_now else None + try: + return datetime.fromtimestamp(float(normalized_text)) + except (TypeError, ValueError, OSError, OverflowError): + pass + try: + return datetime.fromisoformat(normalized_text.replace("Z", "+00:00")) + except ValueError: + return datetime.now() if fallback_now else None + + +def _normalize_string_list(value: Any) -> List[str]: + """将旧版文本或 JSON 字段规范化为字符串列表。 + + 
Args: + value: 待标准化的原始值。 + + Returns: + List[str]: 规范化后的字符串列表。 + """ + if value is None: + return [] + if isinstance(value, list): + return [str(item).strip() for item in value if str(item).strip()] + + normalized_text = str(value).strip() + if not normalized_text: + return [] + try: + parsed_value = json.loads(normalized_text) + except json.JSONDecodeError: + return [normalized_text] + + if isinstance(parsed_value, list): + return [str(item).strip() for item in parsed_value if str(item).strip()] + if isinstance(parsed_value, str): + parsed_text = parsed_value.strip() + return [parsed_text] if parsed_text else [] + if parsed_value is None: + return [] + return [str(parsed_value).strip()] + + +def _normalize_json_dict_text(value: Any) -> Optional[str]: + """将旧版附加配置标准化为 JSON 字典字符串。 + + Args: + value: 待标准化的原始值。 + + Returns: + Optional[str]: 合法的 JSON 字典字符串;若无内容则返回 ``None``。 + """ + if value is None: + return None + if isinstance(value, dict): + return json.dumps(value, ensure_ascii=False) + + normalized_text = str(value).strip() + if not normalized_text: + return None + try: + parsed_value = json.loads(normalized_text) + except json.JSONDecodeError: + return json.dumps({"_legacy_additional_config_raw": normalized_text}, ensure_ascii=False) + + if isinstance(parsed_value, dict): + return json.dumps(parsed_value, ensure_ascii=False) + return json.dumps({"_legacy_additional_config_raw": parsed_value}, ensure_ascii=False) + + +def _normalize_group_cardname_json(value: Any) -> Optional[str]: + """将旧版群昵称字段转换为当前使用的 JSON 结构。 + + Args: + value: 旧版 ``group_nick_name`` 字段值。 + + Returns: + Optional[str]: 新版 ``group_cardname`` JSON 字符串。 + """ + if value is None: + return None + + normalized_text = str(value).strip() + if not normalized_text: + return None + try: + parsed_value = json.loads(normalized_text) + except json.JSONDecodeError: + return None + + if not isinstance(parsed_value, list): + return None + + normalized_items: List[Dict[str, str]] = [] + for item in parsed_value: 
+ if not isinstance(item, Mapping): + continue + group_id = _normalize_required_text(item.get("group_id")) + group_cardname = _normalize_required_text(item.get("group_cardname") or item.get("group_nick_name")) + if not group_id or not group_cardname: + continue + normalized_items.append( + { + "group_id": group_id, + "group_cardname": group_cardname, + } + ) + if not normalized_items: + return None + return json.dumps(normalized_items, ensure_ascii=False) + + +def _normalize_modified_by(value: Any) -> Optional[str]: + """将旧版审核来源字段标准化为当前枚举名称。 + + Args: + value: 待标准化的原始值。 + + Returns: + Optional[str]: 若能识别则返回 ``AI`` / ``USER``,否则返回 ``None``。 + """ + normalized_text = _normalize_required_text(value).lower() + if normalized_text in {"", "null", "none"}: + return None + if normalized_text in {"ai"}: + return "AI" + if normalized_text in {"user"}: + return "USER" + return None + + +def _build_session_id_dict(value: Any, fallback_count: int) -> str: + """将旧版 ``chat_id`` 字段转换为新版 ``session_id_dict``。 + + Args: + value: 旧版 ``chat_id`` 字段值。 + fallback_count: 默认引用次数。 + + Returns: + str: 新版 ``session_id_dict`` JSON 字符串。 + """ + if value is None: + return json.dumps({}, ensure_ascii=False) + + normalized_text = str(value).strip() + if not normalized_text: + return json.dumps({}, ensure_ascii=False) + try: + parsed_value = json.loads(normalized_text) + except json.JSONDecodeError: + return json.dumps({normalized_text: max(fallback_count, 1)}, ensure_ascii=False) + + if isinstance(parsed_value, str): + parsed_text = parsed_value.strip() + if not parsed_text: + return json.dumps({}, ensure_ascii=False) + return json.dumps({parsed_text: max(fallback_count, 1)}, ensure_ascii=False) + if not isinstance(parsed_value, list): + return json.dumps({}, ensure_ascii=False) + + session_counts: Dict[str, int] = {} + for item in parsed_value: + if not isinstance(item, list) or not item: + continue + session_id = _normalize_required_text(item[0]) + if not session_id: + continue + session_count = 
def _build_legacy_message_additional_config(row: Mapping[str, Any]) -> Optional[str]:
    """Build the additional-config JSON used by the new message table.

    The legacy ``additional_config`` content is merged with legacy columns
    that have no dedicated column in the new table.

    Args:
        row: Row of the legacy message table.

    Returns:
        Optional[str]: Content for the new ``additional_config`` field, or
        ``None`` when nothing needs to be stored.
    """
    existing_config_text = _normalize_json_dict_text(row.get("additional_config"))
    merged_config = json.loads(existing_config_text) if existing_config_text else {}

    carried_over_fields = (
        "intercept_message_level",
        "interest_value",
        "key_words",
        "key_words_lite",
        "priority_info",
        "priority_mode",
        "selected_expressions",
    )
    for field_name in carried_over_fields:
        field_value = row.get(field_name)
        if field_value is not None:
            merged_config[field_name] = field_value

    if merged_config:
        return json.dumps(merged_config, ensure_ascii=False)
    return None


def _build_message_raw_content(processed_plain_text: Optional[str], display_message: Optional[str]) -> bytes:
    """Build a placeholder ``raw_content`` for a legacy message.

    The display text is preferred over the processed text.

    Args:
        processed_plain_text: Processed text of the legacy message.
        display_message: Display text of the legacy message.

    Returns:
        bytes: msgpack bytes holding a single text segment, or an empty
        sequence when no text is available.
    """
    message_text = _normalize_optional_text(display_message)
    if not message_text:
        message_text = _normalize_optional_text(processed_plain_text)
    if not message_text:
        return cast(bytes, _EMPTY_MESSAGE_SEQUENCE_BYTES)

    text_segment = {"type": "text", "data": message_text}
    return cast(bytes, msgpack.packb([text_segment], use_bin_type=True))


def _deduce_image_type_name(value: Any) -> str:
    """Convert a legacy image type to the current enum name.

    Args:
        value: Legacy image type field value.

    Returns:
        str: ``EMOJI`` when the value equals ``emoji`` (case-insensitive),
        otherwise ``IMAGE``.
    """
    legacy_type = _normalize_required_text(value, default="image").lower()
    return "EMOJI" if legacy_type == "emoji" else "IMAGE"
def _migrate_chat_sessions(connection: Connection) -> int:
    """Migrate legacy ``chat_streams`` rows into ``chat_sessions``.

    Args:
        connection: Current database connection.

    Returns:
        int: Number of rows for which an insert was issued.
    """
    legacy_table = _load_legacy_table_data(connection, "chat_streams")
    if legacy_table is None:
        return 0

    insert_sql = text(
        """
        INSERT OR IGNORE INTO chat_sessions (
            session_id,
            created_timestamp,
            last_active_timestamp,
            user_id,
            group_id,
            platform
        ) VALUES (
            :session_id,
            :created_timestamp,
            :last_active_timestamp,
            :user_id,
            :group_id,
            :platform
        )
        """
    )
    issued = 0
    for legacy_row in legacy_table.rows:
        session_id = _normalize_required_text(legacy_row.get("stream_id"))
        if not session_id:
            # Rows without a stream id cannot be addressed in the new table.
            continue
        params = {
            "session_id": session_id,
            "created_timestamp": _coerce_datetime(legacy_row.get("create_time"), fallback_now=True),
            "last_active_timestamp": _coerce_datetime(legacy_row.get("last_active_time"), fallback_now=True),
            "user_id": _normalize_optional_text(legacy_row.get("user_id")),
            "group_id": _normalize_optional_text(legacy_row.get("group_id")),
            "platform": _normalize_required_text(legacy_row.get("platform"), default="unknown"),
        }
        connection.execute(insert_sql, params)
        issued += 1
    return issued


def _migrate_model_usage(connection: Connection) -> int:
    """Migrate legacy ``llm_usage`` rows into the new ``llm_usage`` table.

    Args:
        connection: Current database connection.

    Returns:
        int: Number of rows for which an insert was issued.
    """
    legacy_table = _load_legacy_table_data(connection, "llm_usage")
    if legacy_table is None:
        return 0

    insert_sql = text(
        """
        INSERT OR IGNORE INTO llm_usage (
            id,
            model_name,
            model_assign_name,
            model_api_provider_name,
            endpoint,
            user_type,
            request_type,
            time_cost,
            timestamp,
            prompt_tokens,
            completion_tokens,
            total_tokens,
            cost
        ) VALUES (
            :id,
            :model_name,
            :model_assign_name,
            :model_api_provider_name,
            :endpoint,
            :user_type,
            :request_type,
            :time_cost,
            :timestamp,
            :prompt_tokens,
            :completion_tokens,
            :total_tokens,
            :cost
        )
        """
    )
    issued = 0
    for legacy_row in legacy_table.rows:
        params = {
            "id": legacy_row.get("id"),
            "model_name": _normalize_required_text(legacy_row.get("model_name"), default="unknown"),
            "model_assign_name": _normalize_optional_text(legacy_row.get("model_assign_name")),
            "model_api_provider_name": _normalize_required_text(
                legacy_row.get("model_api_provider"), default="unknown"
            ),
            "endpoint": _normalize_optional_text(legacy_row.get("endpoint")),
            # Every migrated row is written with the fixed user type SYSTEM.
            "user_type": "SYSTEM",
            "request_type": _normalize_required_text(legacy_row.get("request_type"), default="unknown"),
            "time_cost": _normalize_float(legacy_row.get("time_cost"), default=0.0),
            "timestamp": _coerce_datetime(legacy_row.get("timestamp"), fallback_now=True),
            "prompt_tokens": _normalize_int(legacy_row.get("prompt_tokens"), default=0),
            "completion_tokens": _normalize_int(legacy_row.get("completion_tokens"), default=0),
            "total_tokens": _normalize_int(legacy_row.get("total_tokens"), default=0),
            "cost": _normalize_float(legacy_row.get("cost"), default=0.0),
        }
        connection.execute(insert_sql, params)
        issued += 1
    return issued
def _migrate_images(connection: Connection) -> int:
    """Migrate legacy ``emoji`` and ``images`` rows into the new ``images`` table.

    Rows are de-duplicated on ``(full_path, image_hash, image_type)`` against
    rows already in the target table and against rows inserted earlier in
    this call.

    Args:
        connection: Current database connection.

    Returns:
        int: Number of rows inserted.
    """
    current_rows = connection.execute(
        text("SELECT full_path, image_hash, image_type FROM images")
    ).mappings().all()
    seen_keys: Set[Tuple[str, str, str]] = {
        (
            _normalize_required_text(current.get("full_path")),
            _normalize_required_text(current.get("image_hash")),
            _normalize_required_text(current.get("image_type")),
        )
        for current in current_rows
    }
    insert_sql = text(
        """
        INSERT INTO images (
            image_hash,
            description,
            full_path,
            image_type,
            emotion,
            query_count,
            is_registered,
            is_banned,
            no_file_flag,
            record_time,
            register_time,
            last_used_time,
            vlm_processed
        ) VALUES (
            :image_hash,
            :description,
            :full_path,
            :image_type,
            :emotion,
            :query_count,
            :is_registered,
            :is_banned,
            :no_file_flag,
            :record_time,
            :register_time,
            :last_used_time,
            :vlm_processed
        )
        """
    )
    inserted = 0

    # Pass 1: the legacy emoji table always maps to the EMOJI image type.
    emoji_source = _load_legacy_table_data(connection, "emoji")
    for legacy_row in emoji_source.rows if emoji_source is not None else []:
        full_path = _normalize_required_text(legacy_row.get("full_path"))
        image_hash = _normalize_required_text(legacy_row.get("emoji_hash"))
        key = (full_path, image_hash, "EMOJI")
        if not full_path or key in seen_keys:
            continue
        params = {
            "image_hash": image_hash,
            "description": _normalize_required_text(legacy_row.get("description")),
            "full_path": full_path,
            "image_type": "EMOJI",
            "emotion": _normalize_optional_text(legacy_row.get("emotion")),
            "query_count": _normalize_int(legacy_row.get("query_count"), default=0),
            "is_registered": _normalize_bool(legacy_row.get("is_registered"), default=False),
            "is_banned": _normalize_bool(legacy_row.get("is_banned"), default=False),
            "no_file_flag": False,
            "record_time": _coerce_datetime(legacy_row.get("record_time"), fallback_now=True),
            "register_time": _coerce_datetime(legacy_row.get("register_time")),
            "last_used_time": _coerce_datetime(legacy_row.get("last_used_time")),
            "vlm_processed": False,
        }
        connection.execute(insert_sql, params)
        seen_keys.add(key)
        inserted += 1

    # Pass 2: the legacy images table carries its own type column.
    image_source = _load_legacy_table_data(connection, "images")
    for legacy_row in image_source.rows if image_source is not None else []:
        full_path = _normalize_required_text(legacy_row.get("path"))
        image_hash = _normalize_required_text(legacy_row.get("emoji_hash"))
        image_type = _deduce_image_type_name(legacy_row.get("type"))
        key = (full_path, image_hash, image_type)
        if not full_path or key in seen_keys:
            continue
        params = {
            "image_hash": image_hash,
            "description": _normalize_required_text(legacy_row.get("description")),
            "full_path": full_path,
            "image_type": image_type,
            "emotion": None,
            "query_count": _normalize_int(legacy_row.get("count"), default=0),
            "is_registered": False,
            "is_banned": False,
            "no_file_flag": False,
            "record_time": _coerce_datetime(legacy_row.get("timestamp"), fallback_now=True),
            "register_time": None,
            "last_used_time": None,
            "vlm_processed": _normalize_bool(legacy_row.get("vlm_processed"), default=False),
        }
        connection.execute(insert_sql, params)
        seen_keys.add(key)
        inserted += 1

    return inserted
def _migrate_messages(connection: Connection) -> int:
    """Migrate legacy ``messages`` rows into ``mai_messages``.

    Args:
        connection: Current database connection.

    Returns:
        int: Number of rows for which an insert was issued.
    """
    legacy_table = _load_legacy_table_data(connection, "messages")
    if legacy_table is None:
        return 0

    insert_sql = text(
        """
        INSERT OR IGNORE INTO mai_messages (
            id,
            message_id,
            timestamp,
            platform,
            user_id,
            user_nickname,
            user_cardname,
            group_id,
            group_name,
            is_mentioned,
            is_at,
            session_id,
            reply_to,
            is_emoji,
            is_picture,
            is_command,
            is_notify,
            raw_content,
            processed_plain_text,
            display_message,
            additional_config
        ) VALUES (
            :id,
            :message_id,
            :timestamp,
            :platform,
            :user_id,
            :user_nickname,
            :user_cardname,
            :group_id,
            :group_name,
            :is_mentioned,
            :is_at,
            :session_id,
            :reply_to,
            :is_emoji,
            :is_picture,
            :is_command,
            :is_notify,
            :raw_content,
            :processed_plain_text,
            :display_message,
            :additional_config
        )
        """
    )
    issued = 0
    for legacy_row in legacy_table.rows:
        # The session id comes from chat_id, falling back to the stream id.
        session_id = _normalize_optional_text(legacy_row.get("chat_id"))
        if not session_id:
            session_id = _normalize_optional_text(legacy_row.get("chat_info_stream_id"))
        if not session_id:
            continue

        processed_plain_text = _normalize_optional_text(legacy_row.get("processed_plain_text"))
        display_message = _normalize_optional_text(legacy_row.get("display_message"))
        platform_source = legacy_row.get("chat_info_platform") or legacy_row.get("user_platform")
        user_id_source = legacy_row.get("user_id") or legacy_row.get("chat_info_user_id")
        nickname_source = legacy_row.get("user_nickname") or legacy_row.get("chat_info_user_nickname")
        cardname_source = legacy_row.get("user_cardname") or legacy_row.get("chat_info_user_cardname")
        params = {
            "id": legacy_row.get("id"),
            "message_id": _normalize_required_text(legacy_row.get("message_id"), default=""),
            "timestamp": _coerce_datetime(legacy_row.get("time"), fallback_now=True),
            "platform": _normalize_required_text(platform_source, default="unknown"),
            "user_id": _normalize_required_text(user_id_source, default=""),
            "user_nickname": _normalize_required_text(nickname_source, default=""),
            "user_cardname": _normalize_optional_text(cardname_source),
            "group_id": _normalize_optional_text(legacy_row.get("chat_info_group_id")),
            "group_name": _normalize_optional_text(legacy_row.get("chat_info_group_name")),
            "is_mentioned": _normalize_bool(legacy_row.get("is_mentioned"), default=False),
            "is_at": _normalize_bool(legacy_row.get("is_at"), default=False),
            "session_id": session_id,
            "reply_to": _normalize_optional_text(legacy_row.get("reply_to")),
            "is_emoji": _normalize_bool(legacy_row.get("is_emoji"), default=False),
            "is_picture": _normalize_bool(legacy_row.get("is_picid"), default=False),
            "is_command": _normalize_bool(legacy_row.get("is_command"), default=False),
            "is_notify": _normalize_bool(legacy_row.get("is_notify"), default=False),
            "raw_content": _build_message_raw_content(processed_plain_text, display_message),
            "processed_plain_text": processed_plain_text,
            "display_message": display_message,
            "additional_config": _build_legacy_message_additional_config(legacy_row),
        }
        connection.execute(insert_sql, params)
        issued += 1
    return issued
def _migrate_action_records(connection: Connection) -> int:
    """Migrate legacy ``action_records`` rows into the new ``action_records`` table.

    Args:
        connection: Current database connection.

    Returns:
        int: Number of rows for which an insert was issued.
    """
    legacy_table = _load_legacy_table_data(connection, "action_records")
    if legacy_table is None:
        return 0

    insert_sql = text(
        """
        INSERT OR IGNORE INTO action_records (
            id,
            action_id,
            timestamp,
            session_id,
            action_name,
            action_reasoning,
            action_data,
            action_builtin_prompt,
            action_display_prompt
        ) VALUES (
            :id,
            :action_id,
            :timestamp,
            :session_id,
            :action_name,
            :action_reasoning,
            :action_data,
            :action_builtin_prompt,
            :action_display_prompt
        )
        """
    )
    issued = 0
    for legacy_row in legacy_table.rows:
        # The session id comes from chat_id, falling back to the stream id.
        session_id = _normalize_optional_text(legacy_row.get("chat_id"))
        if not session_id:
            session_id = _normalize_optional_text(legacy_row.get("chat_info_stream_id"))
        if not session_id:
            continue
        params = {
            "id": legacy_row.get("id"),
            "action_id": _normalize_required_text(legacy_row.get("action_id")),
            "timestamp": _coerce_datetime(legacy_row.get("time"), fallback_now=True),
            "session_id": session_id,
            "action_name": _normalize_required_text(legacy_row.get("action_name"), default="unknown"),
            "action_reasoning": _normalize_optional_text(legacy_row.get("action_reasoning")),
            "action_data": _normalize_optional_text(legacy_row.get("action_data")),
            "action_builtin_prompt": None,
            "action_display_prompt": _normalize_optional_text(legacy_row.get("action_prompt_display")),
        }
        connection.execute(insert_sql, params)
        issued += 1
    return issued
def _migrate_tool_records(connection: Connection) -> int:
    """Copy legacy ``action_records`` rows into the new ``tool_records`` table.

    Args:
        connection: Current database connection.

    Returns:
        int: Number of rows for which an insert was issued.
    """
    legacy_table = _load_legacy_table_data(connection, "action_records")
    if legacy_table is None:
        return 0

    insert_sql = text(
        """
        INSERT OR IGNORE INTO tool_records (
            id,
            tool_id,
            timestamp,
            session_id,
            tool_name,
            tool_reasoning,
            tool_data,
            tool_builtin_prompt,
            tool_display_prompt
        ) VALUES (
            :id,
            :tool_id,
            :timestamp,
            :session_id,
            :tool_name,
            :tool_reasoning,
            :tool_data,
            :tool_builtin_prompt,
            :tool_display_prompt
        )
        """
    )
    issued = 0
    for legacy_row in legacy_table.rows:
        # The session id comes from chat_id, falling back to the stream id.
        session_id = _normalize_optional_text(legacy_row.get("chat_id"))
        if not session_id:
            session_id = _normalize_optional_text(legacy_row.get("chat_info_stream_id"))
        if not session_id:
            continue
        params = {
            "id": legacy_row.get("id"),
            "tool_id": _normalize_required_text(legacy_row.get("action_id")),
            "timestamp": _coerce_datetime(legacy_row.get("time"), fallback_now=True),
            "session_id": session_id,
            "tool_name": _normalize_required_text(legacy_row.get("action_name"), default="unknown"),
            "tool_reasoning": _normalize_optional_text(legacy_row.get("action_reasoning")),
            "tool_data": _normalize_optional_text(legacy_row.get("action_data")),
            "tool_builtin_prompt": None,
            "tool_display_prompt": _normalize_optional_text(legacy_row.get("action_prompt_display")),
        }
        connection.execute(insert_sql, params)
        issued += 1
    return issued
_normalize_required_text(row.get("action_id")), + "timestamp": _coerce_datetime(row.get("time"), fallback_now=True), + "session_id": session_id, + "tool_name": _normalize_required_text(row.get("action_name"), default="unknown"), + "tool_reasoning": _normalize_optional_text(row.get("action_reasoning")), + "tool_data": _normalize_optional_text(row.get("action_data")), + "tool_builtin_prompt": None, + "tool_display_prompt": _normalize_optional_text(row.get("action_prompt_display")), + }, + ) + migrated_count += 1 + return migrated_count + + +def _migrate_online_time(connection: Connection) -> int: + """迁移旧版 ``online_time`` 到新版 ``online_time``。 + + Args: + connection: 当前数据库连接。 + + Returns: + int: 迁移成功的记录数。 + """ + legacy_table = _load_legacy_table_data(connection, "online_time") + if legacy_table is None: + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO online_time ( + id, + timestamp, + duration_minutes, + start_timestamp, + end_timestamp + ) VALUES ( + :id, + :timestamp, + :duration_minutes, + :start_timestamp, + :end_timestamp + ) + """ + ) + for row in legacy_table.rows: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "timestamp": _coerce_datetime(row.get("timestamp"), fallback_now=True), + "duration_minutes": _normalize_int(row.get("duration"), default=0), + "start_timestamp": _coerce_datetime(row.get("start_timestamp"), fallback_now=True), + "end_timestamp": _coerce_datetime(row.get("end_timestamp"), fallback_now=True), + }, + ) + migrated_count += 1 + return migrated_count + + +def _migrate_person_info(connection: Connection) -> int: + """迁移旧版 ``person_info`` 到新版 ``person_info``。 + + Args: + connection: 当前数据库连接。 + + Returns: + int: 迁移成功的记录数。 + """ + legacy_table = _load_legacy_table_data(connection, "person_info") + if legacy_table is None: + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO person_info ( + id, + is_known, + person_id, + person_name, + name_reason, + platform, 
+ user_id, + user_nickname, + group_cardname, + memory_points, + know_counts, + first_known_time, + last_known_time + ) VALUES ( + :id, + :is_known, + :person_id, + :person_name, + :name_reason, + :platform, + :user_id, + :user_nickname, + :group_cardname, + :memory_points, + :know_counts, + :first_known_time, + :last_known_time + ) + """ + ) + for row in legacy_table.rows: + first_known_time = _coerce_datetime(row.get("know_times")) or _coerce_datetime(row.get("know_since")) + last_known_time = _coerce_datetime(row.get("last_know")) or _coerce_datetime(row.get("know_since")) + memory_points = _normalize_string_list(row.get("memory_points")) + connection.execute( + insert_sql, + { + "id": row.get("id"), + "is_known": _normalize_bool(row.get("is_known"), default=False), + "person_id": _normalize_required_text(row.get("person_id")), + "person_name": _normalize_optional_text(row.get("person_name")), + "name_reason": _normalize_optional_text(row.get("name_reason")), + "platform": _normalize_required_text(row.get("platform"), default="unknown"), + "user_id": _normalize_required_text(row.get("user_id"), default=""), + "user_nickname": _normalize_required_text(row.get("nickname"), default=""), + "group_cardname": _normalize_group_cardname_json(row.get("group_nick_name")), + "memory_points": json.dumps(memory_points, ensure_ascii=False) if memory_points else None, + "know_counts": 1 if _normalize_bool(row.get("is_known"), default=False) else 0, + "first_known_time": first_known_time, + "last_known_time": last_known_time, + }, + ) + migrated_count += 1 + return migrated_count + + +def _migrate_expressions(connection: Connection) -> int: + """迁移旧版 ``expression`` 到新版 ``expressions``。 + + Args: + connection: 当前数据库连接。 + + Returns: + int: 迁移成功的记录数。 + """ + legacy_table = _load_legacy_table_data(connection, "expression") + if legacy_table is None: + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO expressions ( + id, + situation, + style, + 
content_list, + count, + last_active_time, + create_time, + session_id, + checked, + rejected, + modified_by + ) VALUES ( + :id, + :situation, + :style, + :content_list, + :count, + :last_active_time, + :create_time, + :session_id, + :checked, + :rejected, + :modified_by + ) + """ + ) + for row in legacy_table.rows: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "situation": _normalize_required_text(row.get("situation"), default=""), + "style": _normalize_required_text(row.get("style"), default=""), + "content_list": json.dumps(_normalize_string_list(row.get("content_list")), ensure_ascii=False), + "count": _normalize_int(row.get("count"), default=1), + "last_active_time": _coerce_datetime(row.get("last_active_time"), fallback_now=True), + "create_time": _coerce_datetime(row.get("create_date"), fallback_now=True), + "session_id": _normalize_optional_text(row.get("chat_id")), + "checked": _normalize_bool(row.get("checked"), default=False), + "rejected": _normalize_bool(row.get("rejected"), default=False), + "modified_by": _normalize_modified_by(row.get("modified_by")), + }, + ) + migrated_count += 1 + return migrated_count + + +def _migrate_jargons(connection: Connection) -> int: + """迁移旧版 ``jargon`` 到新版 ``jargons``。 + + Args: + connection: 当前数据库连接。 + + Returns: + int: 迁移成功的记录数。 + """ + legacy_table = _load_legacy_table_data(connection, "jargon") + if legacy_table is None: + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO jargons ( + id, + content, + raw_content, + meaning, + session_id_dict, + count, + is_jargon, + is_complete, + is_global, + last_inference_count, + inference_with_context, + inference_with_content_only + ) VALUES ( + :id, + :content, + :raw_content, + :meaning, + :session_id_dict, + :count, + :is_jargon, + :is_complete, + :is_global, + :last_inference_count, + :inference_with_context, + :inference_with_content_only + ) + """ + ) + for row in legacy_table.rows: + count = 
_normalize_int(row.get("count"), default=0) + connection.execute( + insert_sql, + { + "id": row.get("id"), + "content": _normalize_required_text(row.get("content"), default=""), + "raw_content": json.dumps(_normalize_string_list(row.get("raw_content")), ensure_ascii=False) + if row.get("raw_content") is not None + else None, + "meaning": _normalize_required_text(row.get("meaning")), + "session_id_dict": _build_session_id_dict(row.get("chat_id"), fallback_count=max(count, 1)), + "count": count, + "is_jargon": _normalize_optional_bool(row.get("is_jargon")), + "is_complete": _normalize_bool(row.get("is_complete"), default=False), + "is_global": _normalize_bool(row.get("is_global"), default=False), + "last_inference_count": _normalize_int(row.get("last_inference_count"), default=0), + "inference_with_context": _normalize_optional_text(row.get("inference_with_context")), + "inference_with_content_only": _normalize_optional_text( + row.get("inference_content_only") or row.get("inference_with_content_only") + ), + }, + ) + migrated_count += 1 + return migrated_count + + +def _migrate_chat_history(connection: Connection) -> int: + """迁移旧版 ``chat_history`` 到新版 ``chat_history``。 + + Args: + connection: 当前数据库连接。 + + Returns: + int: 迁移成功的记录数。 + """ + legacy_table = _load_legacy_table_data(connection, "chat_history") + if legacy_table is None: + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO chat_history ( + id, + session_id, + start_timestamp, + end_timestamp, + query_count, + query_forget_count, + original_messages, + participants, + theme, + keywords, + summary + ) VALUES ( + :id, + :session_id, + :start_timestamp, + :end_timestamp, + :query_count, + :query_forget_count, + :original_messages, + :participants, + :theme, + :keywords, + :summary + ) + """ + ) + for row in legacy_table.rows: + session_id = _normalize_required_text(row.get("chat_id")) + if not session_id: + continue + connection.execute( + insert_sql, + { + "id": row.get("id"), 
+ "session_id": session_id, + "start_timestamp": _coerce_datetime(row.get("start_time"), fallback_now=True), + "end_timestamp": _coerce_datetime(row.get("end_time"), fallback_now=True), + "query_count": _normalize_int(row.get("count"), default=0), + "query_forget_count": _normalize_int(row.get("forget_times"), default=0), + "original_messages": _normalize_required_text(row.get("original_text")), + "participants": _normalize_required_text(row.get("participants"), default="[]"), + "theme": _normalize_required_text(row.get("theme"), default=""), + "keywords": _normalize_required_text(row.get("keywords"), default="[]"), + "summary": _normalize_required_text(row.get("summary"), default=""), + }, + ) + migrated_count += 1 + return migrated_count + + +def _migrate_thinking_questions(connection: Connection) -> int: + """迁移旧版 ``thinking_back`` 到新版 ``thinking_questions``。 + + Args: + connection: 当前数据库连接。 + + Returns: + int: 迁移成功的记录数。 + """ + legacy_table = _load_legacy_table_data(connection, "thinking_back") + if legacy_table is None: + return 0 + + migrated_count = 0 + insert_sql = text( + """ + INSERT OR IGNORE INTO thinking_questions ( + id, + question, + context, + found_answer, + answer, + thinking_steps, + created_timestamp, + updated_timestamp + ) VALUES ( + :id, + :question, + :context, + :found_answer, + :answer, + :thinking_steps, + :created_timestamp, + :updated_timestamp + ) + """ + ) + for row in legacy_table.rows: + connection.execute( + insert_sql, + { + "id": row.get("id"), + "question": _normalize_required_text(row.get("question"), default=""), + "context": _normalize_optional_text(row.get("context")), + "found_answer": _normalize_bool(row.get("found_answer"), default=False), + "answer": _normalize_optional_text(row.get("answer")), + "thinking_steps": _normalize_optional_text(row.get("thinking_steps")), + "created_timestamp": _coerce_datetime(row.get("create_time"), fallback_now=True), + "updated_timestamp": _coerce_datetime(row.get("update_time"), 
fallback_now=True), + }, + ) + migrated_count += 1 + return migrated_count diff --git a/src/common/database/migrations/manager.py b/src/common/database/migrations/manager.py new file mode 100644 index 00000000..d33e6926 --- /dev/null +++ b/src/common/database/migrations/manager.py @@ -0,0 +1,205 @@ +"""数据库迁移编排器。""" + +from typing import Callable, Optional + +from sqlalchemy.engine import Connection, Engine + +from src.common.logger import get_logger + +from .exceptions import DatabaseMigrationExecutionError +from .models import DatabaseMigrationState, MigrationExecutionContext, MigrationPlan +from .planner import MigrationPlanner +from .progress import BaseMigrationProgressReporter, create_rich_migration_progress_reporter +from .registry import MigrationRegistry +from .resolver import SchemaVersionResolver +from .version_store import SQLiteUserVersionStore + +logger = get_logger("database_migration") + + +class DatabaseMigrationManager: + """数据库迁移编排器。 + + 该类只负责基础设施层面的编排工作,包括: + 1. 解析当前数据库版本; + 2. 生成迁移计划; + 3. 顺序执行已注册迁移步骤; + 4. 
在每一步成功后更新 ``user_version``。 + + 当前模块不内置任何业务迁移步骤,也不会自动接入项目启动流程。 + """ + + def __init__( + self, + engine: Engine, + registry: Optional[MigrationRegistry] = None, + planner: Optional[MigrationPlanner] = None, + resolver: Optional[SchemaVersionResolver] = None, + version_store: Optional[SQLiteUserVersionStore] = None, + progress_reporter_factory: Optional[Callable[[], BaseMigrationProgressReporter]] = None, + ) -> None: + """初始化数据库迁移编排器。 + + Args: + engine: 目标数据库引擎。 + registry: 迁移步骤注册表。 + planner: 迁移计划生成器。 + resolver: 数据库版本解析器。 + version_store: 版本存储器。 + progress_reporter_factory: 迁移进度上报器工厂。 + """ + self.engine = engine + self.registry = registry or MigrationRegistry() + self.planner = planner or MigrationPlanner() + self.resolver = resolver or SchemaVersionResolver() + self.version_store = version_store or SQLiteUserVersionStore() + self.progress_reporter_factory = progress_reporter_factory or create_rich_migration_progress_reporter + + def describe_state(self, target_version: Optional[int] = None) -> DatabaseMigrationState: + """描述当前数据库的迁移状态。 + + Args: + target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。 + + Returns: + DatabaseMigrationState: 当前数据库迁移状态。 + """ + with self.engine.connect() as connection: + resolved_version = self.resolver.resolve(connection) + + effective_target_version = self._resolve_target_version(target_version) + migration_plan = self.planner.plan( + current_version=resolved_version.version, + target_version=effective_target_version, + registry=self.registry, + ) + return DatabaseMigrationState( + resolved_version=resolved_version, + target_version=effective_target_version, + plan=migration_plan, + ) + + def plan(self, target_version: Optional[int] = None) -> MigrationPlan: + """生成当前数据库的迁移计划。 + + Args: + target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。 + + Returns: + MigrationPlan: 当前数据库对应的迁移计划。 + """ + return self.describe_state(target_version=target_version).plan + + def migrate(self, target_version: Optional[int] = None) -> MigrationPlan: + """执行迁移计划。 + + 
注意: + 若当前数据库是通过结构探测得出的版本,且计划为空,本方法不会自动把该 + 版本写回 ``user_version``。这样做是为了避免在尚未明确接入策略前引入隐式 + 副作用。 + + Args: + target_version: 目标数据库版本;未提供时使用注册表中声明的最新版本。 + + Returns: + MigrationPlan: 已执行的迁移计划。 + + Raises: + DatabaseMigrationExecutionError: 当迁移步骤执行失败时抛出。 + """ + migration_state = self.describe_state(target_version=target_version) + migration_plan = migration_state.plan + if migration_plan.is_empty(): + logger.info("数据库迁移计划为空,跳过执行。") + return migration_plan + + current_version = migration_state.resolved_version.version + total_steps = migration_plan.step_count() + for step_index, step in enumerate(migration_plan.steps, start=1): + logger.info( + f"开始执行数据库迁移步骤: {step.name} ({step.version_from} -> {step.version_to})" + ) + try: + with self.progress_reporter_factory() as progress_reporter: + if step.transactional: + with self.engine.begin() as connection: + execution_context = self._build_execution_context( + connection=connection, + current_version=current_version, + migration_plan=migration_plan, + step_index=step_index, + step_name=step.name, + total_steps=total_steps, + progress_reporter=progress_reporter, + ) + step.run(execution_context) + self.version_store.write_version(connection, step.version_to) + else: + with self.engine.connect() as connection: + execution_context = self._build_execution_context( + connection=connection, + current_version=current_version, + migration_plan=migration_plan, + step_index=step_index, + step_name=step.name, + total_steps=total_steps, + progress_reporter=progress_reporter, + ) + step.run(execution_context) + self.version_store.write_version(connection, step.version_to) + connection.commit() + except Exception as exc: + raise DatabaseMigrationExecutionError( + f"执行迁移步骤 {step.name} ({step.version_from} -> {step.version_to}) 失败。" + ) from exc + current_version = step.version_to + logger.info(f"数据库迁移步骤执行完成: {step.name},当前版本已更新为 {current_version}") + + return migration_plan + + def _resolve_target_version(self, target_version: 
Optional[int]) -> int: + """解析最终目标版本号。 + + Args: + target_version: 调用方显式指定的目标版本。 + + Returns: + int: 最终用于规划和执行的目标版本号。 + """ + if target_version is not None: + return target_version + return self.registry.latest_version() + + def _build_execution_context( + self, + connection: Connection, + current_version: int, + migration_plan: MigrationPlan, + step_index: int, + step_name: str, + total_steps: int, + progress_reporter: BaseMigrationProgressReporter, + ) -> MigrationExecutionContext: + """构建单个迁移步骤的执行上下文。 + + Args: + connection: 当前迁移步骤使用的数据库连接。 + current_version: 当前数据库版本。 + migration_plan: 当前迁移计划。 + step_index: 当前步骤序号,从 ``1`` 开始。 + step_name: 当前步骤名称。 + total_steps: 计划总步骤数。 + progress_reporter: 当前步骤使用的进度上报器。 + + Returns: + MigrationExecutionContext: 当前步骤的执行上下文对象。 + """ + return MigrationExecutionContext( + connection=connection, + current_version=current_version, + target_version=migration_plan.target_version, + step_index=step_index, + step_name=step_name, + total_steps=total_steps, + progress_reporter=progress_reporter, + ) diff --git a/src/common/database/migrations/models.py b/src/common/database/migrations/models.py new file mode 100644 index 00000000..bc8cf488 --- /dev/null +++ b/src/common/database/migrations/models.py @@ -0,0 +1,285 @@ +"""数据库迁移基础设施核心数据模型。""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum +from typing import Callable, Dict, List, Optional, TYPE_CHECKING + +from sqlalchemy.engine import Connection + +if TYPE_CHECKING: + from .progress import BaseMigrationProgressReporter + + +def _utc_now() -> datetime: + """返回当前 UTC 时间。 + + Returns: + datetime: 当前 UTC 时间。 + """ + return datetime.now(timezone.utc) + + +class SchemaVersionSource(str, Enum): + """数据库版本来源。""" + + PRAGMA = "pragma" + DETECTOR = "detector" + EMPTY_DATABASE = "empty_database" + + +@dataclass(frozen=True) +class ColumnSchema: + """数据库列结构快照。""" + + name: str + declared_type: str + 
default_value: Optional[str] + is_not_null: bool + primary_key_position: int + + +@dataclass(frozen=True) +class TableSchema: + """数据库表结构快照。""" + + name: str + columns: Dict[str, ColumnSchema] + + def has_column(self, column_name: str) -> bool: + """判断表中是否存在指定列。 + + Args: + column_name: 待检查的列名。 + + Returns: + bool: 若列存在则返回 ``True``,否则返回 ``False``。 + """ + return column_name in self.columns + + def get_column(self, column_name: str) -> Optional[ColumnSchema]: + """获取指定列的结构信息。 + + Args: + column_name: 待获取的列名。 + + Returns: + Optional[ColumnSchema]: 列存在时返回列结构,否则返回 ``None``。 + """ + return self.columns.get(column_name) + + def column_names(self) -> List[str]: + """返回当前表中全部列名。 + + Returns: + List[str]: 按字母顺序排列的列名列表。 + """ + return sorted(self.columns) + + +@dataclass(frozen=True) +class DatabaseSchemaSnapshot: + """数据库结构快照。""" + + tables: Dict[str, TableSchema] + + def is_empty(self) -> bool: + """判断数据库是否没有任何用户表。 + + Returns: + bool: 若数据库中没有用户表则返回 ``True``。 + """ + return not self.tables + + def has_table(self, table_name: str) -> bool: + """判断数据库是否存在指定表。 + + Args: + table_name: 待检查的表名。 + + Returns: + bool: 若表存在则返回 ``True``,否则返回 ``False``。 + """ + return table_name in self.tables + + def has_column(self, table_name: str, column_name: str) -> bool: + """判断数据库指定表中是否存在指定列。 + + Args: + table_name: 待检查的表名。 + column_name: 待检查的列名。 + + Returns: + bool: 若表和列均存在则返回 ``True``。 + """ + table_schema = self.get_table(table_name) + if table_schema is None: + return False + return table_schema.has_column(column_name) + + def get_table(self, table_name: str) -> Optional[TableSchema]: + """获取指定表的结构信息。 + + Args: + table_name: 待获取的表名。 + + Returns: + Optional[TableSchema]: 表存在时返回对应结构,否则返回 ``None``。 + """ + return self.tables.get(table_name) + + def table_names(self) -> List[str]: + """返回当前数据库中的全部用户表名。 + + Returns: + List[str]: 按字母顺序排列的表名列表。 + """ + return sorted(self.tables) + + +@dataclass(frozen=True) +class ResolvedSchemaVersion: + """解析后的数据库版本信息。""" + + version: int + source: 
SchemaVersionSource + detector_name: Optional[str] = None + snapshot: Optional[DatabaseSchemaSnapshot] = None + + +@dataclass(frozen=True) +class MigrationExecutionContext: + """单个迁移步骤的执行上下文。""" + + connection: Connection + current_version: int + target_version: int + step_index: int + step_name: str + total_steps: int + started_at: datetime = field(default_factory=_utc_now) + progress_reporter: Optional["BaseMigrationProgressReporter"] = None + + def is_last_step(self) -> bool: + """判断当前步骤是否为最后一步。 + + Returns: + bool: 若当前步骤已是计划中的最后一步则返回 ``True``。 + """ + return self.step_index >= self.total_steps + + def start_progress( + self, + total: int, + description: str = "总迁移进度", + unit_name: str = "表", + ) -> None: + """启动当前迁移步骤的进度展示。 + + Args: + total: 当前步骤需要处理的总项目数。 + description: 进度描述文本。 + unit_name: 进度单位名称。 + """ + if self.progress_reporter is None: + return + self.progress_reporter.start(total=total, description=description, unit_name=unit_name) + + def advance_progress(self, advance: int = 1, item_name: Optional[str] = None) -> None: + """推进当前迁移步骤的进度展示。 + + Args: + advance: 本次推进的步数。 + item_name: 当前完成的项目名称。 + """ + if self.progress_reporter is None: + return + self.progress_reporter.advance(advance=advance, item_name=item_name) + + +MigrationHandler = Callable[[MigrationExecutionContext], None] + + +@dataclass(frozen=True) +class MigrationStep: + """单个数据库迁移步骤定义。""" + + version_from: int + version_to: int + name: str + description: str + handler: MigrationHandler + transactional: bool = True + + def __post_init__(self) -> None: + """校验迁移步骤定义是否合法。 + + Raises: + ValueError: 当版本号不合法或迁移方向错误时抛出。 + """ + if self.version_from < 0: + raise ValueError("迁移起始版本不能小于 0。") + if self.version_to <= self.version_from: + raise ValueError("迁移目标版本必须大于起始版本。") + + def run(self, context: MigrationExecutionContext) -> None: + """执行当前迁移步骤。 + + Args: + context: 当前迁移步骤的执行上下文。 + """ + self.handler(context) + + +@dataclass(frozen=True) +class MigrationPlan: + """数据库迁移执行计划。""" + + current_version: 
int + target_version: int + steps: List[MigrationStep] + + def is_empty(self) -> bool: + """判断迁移计划是否为空。 + + Returns: + bool: 若无需执行任何迁移步骤则返回 ``True``。 + """ + return not self.steps + + def step_count(self) -> int: + """返回迁移计划中的步骤数量。 + + Returns: + int: 当前计划中的迁移步骤数。 + """ + return len(self.steps) + + def latest_reachable_version(self) -> int: + """返回该计划执行后的最终版本。 + + Returns: + int: 若计划为空则返回当前版本,否则返回最后一步的目标版本。 + """ + if self.is_empty(): + return self.current_version + return self.steps[-1].version_to + + +@dataclass(frozen=True) +class DatabaseMigrationState: + """数据库迁移状态描述。""" + + resolved_version: ResolvedSchemaVersion + target_version: int + plan: MigrationPlan + + def requires_migration(self) -> bool: + """判断当前状态是否需要执行迁移。 + + Returns: + bool: 若计划中存在待执行迁移步骤则返回 ``True``。 + """ + return not self.plan.is_empty() diff --git a/src/common/database/migrations/planner.py b/src/common/database/migrations/planner.py new file mode 100644 index 00000000..eca98c27 --- /dev/null +++ b/src/common/database/migrations/planner.py @@ -0,0 +1,108 @@ +"""数据库迁移计划生成器。""" + +from typing import List + +from .exceptions import ( + DatabaseMigrationPlanningError, + MissingMigrationStepError, + UnsupportedMigrationDirectionError, +) +from .models import MigrationPlan, MigrationStep +from .registry import MigrationRegistry + + +class MigrationPlanner: + """数据库迁移计划生成器。""" + + def plan( + self, + current_version: int, + target_version: int, + registry: MigrationRegistry, + ) -> MigrationPlan: + """根据当前版本与目标版本生成迁移计划。 + + Args: + current_version: 当前数据库版本。 + target_version: 目标数据库版本。 + registry: 迁移步骤注册表。 + + Returns: + MigrationPlan: 按顺序执行的迁移计划。 + + Raises: + DatabaseMigrationPlanningError: 当版本号非法时抛出。 + MissingMigrationStepError: 当所需迁移步骤缺失时抛出。 + UnsupportedMigrationDirectionError: 当请求降级迁移时抛出。 + """ + self._validate_version(current_version, "current_version") + self._validate_version(target_version, "target_version") + + if target_version < current_version: + raise 
UnsupportedMigrationDirectionError( + f"当前仅支持升级迁移,不支持从 {current_version} 降级到 {target_version}。" + ) + if target_version == current_version: + return MigrationPlan(current_version=current_version, target_version=target_version, steps=[]) + + steps = self._build_steps(current_version, target_version, registry) + return MigrationPlan(current_version=current_version, target_version=target_version, steps=steps) + + def plan_to_latest(self, current_version: int, registry: MigrationRegistry) -> MigrationPlan: + """生成迁移到注册表最新版本的执行计划。 + + Args: + current_version: 当前数据库版本。 + registry: 迁移步骤注册表。 + + Returns: + MigrationPlan: 指向最新版本的迁移计划。 + """ + target_version = registry.latest_version() + return self.plan(current_version=current_version, target_version=target_version, registry=registry) + + def _build_steps( + self, + current_version: int, + target_version: int, + registry: MigrationRegistry, + ) -> List[MigrationStep]: + """按顺序拼装迁移步骤链。 + + Args: + current_version: 当前数据库版本。 + target_version: 目标数据库版本。 + registry: 迁移步骤注册表。 + + Returns: + List[MigrationStep]: 按顺序执行的迁移步骤列表。 + + Raises: + MissingMigrationStepError: 当中间某一版本缺少迁移步骤时抛出。 + """ + planned_steps: List[MigrationStep] = [] + next_version = current_version + + while next_version < target_version: + step = registry.get_step(next_version) + if step is None: + raise MissingMigrationStepError( + f"缺少从版本 {next_version} 升级到版本 {next_version + 1} 的迁移步骤。" + ) + planned_steps.append(step) + next_version = step.version_to + + return planned_steps + + def _validate_version(self, version: int, field_name: str) -> None: + """校验版本号是否合法。 + + Args: + version: 待校验的版本号。 + field_name: 当前版本号对应的字段名。 + + Raises: + DatabaseMigrationPlanningError: 当版本号非法时抛出。 + """ + if version < 0: + raise DatabaseMigrationPlanningError(f"{field_name} 不能小于 0: {version}") diff --git a/src/common/database/migrations/progress.py b/src/common/database/migrations/progress.py new file mode 100644 index 00000000..4e358ed7 --- /dev/null +++ 
b/src/common/database/migrations/progress.py @@ -0,0 +1,272 @@ +"""数据库迁移进度展示工具。""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from datetime import timedelta +from typing import Optional + +from rich.console import Console +from rich.progress import BarColumn, Progress, ProgressColumn, Task, TaskID +from rich.text import Text + + +def _format_duration(total_seconds: Optional[float]) -> str: + """将秒数格式化为适合展示的耗时文本。 + + Args: + total_seconds: 总秒数;为空时表示暂不可用。 + + Returns: + str: 格式化后的耗时文本。 + """ + if total_seconds is None: + return "--:--:--" + safe_seconds = max(total_seconds, 0.0) + return str(timedelta(seconds=int(safe_seconds))) + + +class MigrationSummaryColumn(ProgressColumn): + """渲染数据库迁移总进度摘要列。""" + + def render(self, task: Task) -> Text: + """渲染当前任务的总进度摘要。 + + Args: + task: 当前进度任务对象。 + + Returns: + Text: 渲染后的摘要文本。 + """ + display_total = task.fields.get("display_total", task.total) + total_text = "?" if display_total is None else str(int(display_total)) + completed_text = str(int(task.completed)) + return Text(f"总迁移进度({completed_text}/{total_text})") + + +class MigrationSpeedColumn(ProgressColumn): + """渲染数据库迁移速度列。""" + + def render(self, task: Task) -> Text: + """渲染当前任务的速度信息。 + + Args: + task: 当前进度任务对象。 + + Returns: + Text: 渲染后的速度文本。 + """ + unit_name = str(task.fields.get("unit_name", "项")) + if task.speed is None or task.speed <= 0: + return Text(f"-- {unit_name}/s") + return Text(f"{task.speed:.2f} {unit_name}/s") + + +class MigrationElapsedColumn(ProgressColumn): + """渲染数据库迁移已用时间列。""" + + def render(self, task: Task) -> Text: + """渲染当前任务的已用时间。 + + Args: + task: 当前进度任务对象。 + + Returns: + Text: 渲染后的已用时间文本。 + """ + return Text(f"已用时间 {_format_duration(task.elapsed)}") + + +class MigrationRemainingColumn(ProgressColumn): + """渲染数据库迁移预估剩余时间列。""" + + def render(self, task: Task) -> Text: + """渲染当前任务的预估剩余时间。 + + Args: + task: 当前进度任务对象。 + + Returns: + Text: 渲染后的预估剩余时间文本。 + """ + return Text(f"预估时间 
{_format_duration(task.time_remaining)}") + + +class BaseMigrationProgressReporter(ABC): + """数据库迁移进度上报器基类。""" + + def __enter__(self) -> "BaseMigrationProgressReporter": + """进入进度上报上下文。 + + Returns: + BaseMigrationProgressReporter: 当前上报器实例。 + """ + self.open() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """退出进度上报上下文。 + + Args: + exc_type: 异常类型。 + exc_value: 异常实例。 + traceback: 异常追踪对象。 + """ + del exc_type, exc_value, traceback + self.close() + + @abstractmethod + def open(self) -> None: + """打开进度上报资源。""" + + @abstractmethod + def close(self) -> None: + """关闭进度上报资源。""" + + @abstractmethod + def start( + self, + total: int, + description: str = "总迁移进度", + unit_name: str = "表", + ) -> None: + """启动一个新的迁移进度任务。 + + Args: + total: 任务总数。 + description: 任务描述。 + unit_name: 进度单位名称。 + """ + + @abstractmethod + def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None: + """推进当前迁移进度任务。 + + Args: + advance: 本次推进的步数。 + item_name: 当前完成的项目名称。 + """ + + +class NullMigrationProgressReporter(BaseMigrationProgressReporter): + """不输出任何内容的空进度上报器。""" + + def close(self) -> None: + """关闭空进度上报器。""" + + def start( + self, + total: int, + description: str = "总迁移进度", + unit_name: str = "表", + ) -> None: + """启动空进度任务。 + + Args: + total: 任务总数。 + description: 任务描述。 + unit_name: 进度单位名称。 + """ + del total, description, unit_name + + def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None: + """推进空进度任务。 + + Args: + advance: 本次推进的步数。 + item_name: 当前完成的项目名称。 + """ + del advance, item_name + + +class RichMigrationProgressReporter(BaseMigrationProgressReporter): + """基于 ``rich`` 的数据库迁移进度上报器。""" + + def __init__( + self, + console: Optional[Console] = None, + disable: Optional[bool] = None, + refresh_per_second: int = 10, + ) -> None: + """初始化 ``rich`` 迁移进度上报器。 + + Args: + console: 输出使用的 ``rich`` 控制台。 + disable: 是否禁用进度条;为空时根据终端能力自动判断。 + refresh_per_second: 每秒刷新次数。 + """ + self.console = console or Console() + self.disable = 
disable + self.refresh_per_second = refresh_per_second + self._progress: Optional[Progress] = None + self._task_id: Optional[TaskID] = None + + def open(self) -> None: + """打开 ``rich`` 进度条资源。""" + effective_disable = not self.console.is_terminal if self.disable is None else self.disable + self._progress = Progress( + MigrationSummaryColumn(), + BarColumn(), + MigrationSpeedColumn(), + MigrationElapsedColumn(), + MigrationRemainingColumn(), + console=self.console, + transient=False, + disable=effective_disable, + refresh_per_second=self.refresh_per_second, + expand=True, + ) + self._progress.start() + + def close(self) -> None: + """关闭 ``rich`` 进度条资源。""" + if self._progress is None: + return + self._progress.stop() + self._progress = None + self._task_id = None + + def start( + self, + total: int, + description: str = "总迁移进度", + unit_name: str = "表", + ) -> None: + """启动一个新的 ``rich`` 迁移进度任务。 + + Args: + total: 任务总数。 + description: 任务描述。 + unit_name: 进度单位名称。 + """ + if self._progress is None: + self.open() + assert self._progress is not None + effective_total = max(total, 1) + self._task_id = self._progress.add_task( + description, + total=effective_total, + display_total=total, + unit_name=unit_name, + ) + + def advance(self, advance: int = 1, item_name: Optional[str] = None) -> None: + """推进当前 ``rich`` 迁移进度任务。 + + Args: + advance: 本次推进的步数。 + item_name: 当前完成的项目名称。 + """ + del item_name + if self._progress is None or self._task_id is None: + return + self._progress.update(self._task_id, advance=advance) + + +def create_rich_migration_progress_reporter() -> BaseMigrationProgressReporter: + """创建默认的 ``rich`` 迁移进度上报器。 + + Returns: + BaseMigrationProgressReporter: 默认迁移进度上报器实例。 + """ + return RichMigrationProgressReporter() diff --git a/src/common/database/migrations/registry.py b/src/common/database/migrations/registry.py new file mode 100644 index 00000000..fb9d893b --- /dev/null +++ b/src/common/database/migrations/registry.py @@ -0,0 +1,98 @@ +"""数据库迁移步骤注册表。""" + 
from typing import Dict, List, Optional

from .exceptions import DatabaseMigrationConfigurationError
from .models import MigrationStep


class MigrationRegistry:
    """Registry of migration steps, indexed by the version each step starts from."""

    def __init__(self, steps: Optional[List[MigrationStep]] = None) -> None:
        """Initialize the registry.

        Args:
            steps: Optional migration steps to register immediately.
        """
        self._steps_by_from_version: Dict[int, MigrationStep] = {}
        if steps:
            self.register_many(steps)

    def register(self, step: MigrationStep) -> None:
        """Register a single migration step.

        The registry requires every step to upgrade between adjacent versions
        only, which keeps the migration chain easy to audit and replay and
        simplifies troubleshooting of production issues.

        Args:
            step: Migration step definition to register.

        Raises:
            DatabaseMigrationConfigurationError: If the step does not span
                exactly one version or conflicts with a registered step.
        """
        if step.version_to != step.version_from + 1:
            raise DatabaseMigrationConfigurationError(
                "迁移步骤必须使用相邻版本号定义,例如 2 -> 3。"
            )
        if step.version_from in self._steps_by_from_version:
            existing_step = self._steps_by_from_version[step.version_from]
            raise DatabaseMigrationConfigurationError(
                f"版本 {step.version_from} 已存在迁移步骤: {existing_step.name}"
            )
        # Defensive invariant check: given the adjacency rule above, a clash on
        # ``version_to`` implies a clash on ``version_from`` (already rejected),
        # so this only fires if a registered step no longer satisfies that rule.
        for registered_step in self._steps_by_from_version.values():
            if registered_step.version_to == step.version_to:
                raise DatabaseMigrationConfigurationError(
                    f"目标版本 {step.version_to} 已由迁移步骤 {registered_step.name} 占用。"
                )
        self._steps_by_from_version[step.version_from] = step

    def register_many(self, steps: List[MigrationStep]) -> None:
        """Register several migration steps, in the given order.

        Args:
            steps: Migration steps to register.

        Raises:
            DatabaseMigrationConfigurationError: Propagated from ``register``
                for the first invalid or conflicting step; steps registered
                before it stay registered.
        """
        for step in steps:
            self.register(step)

    def get_step(self, version_from: int) -> Optional[MigrationStep]:
        """Return the migration step that starts at ``version_from``.

        Args:
            version_from: Starting version of the wanted step.

        Returns:
            Optional[MigrationStep]: The registered step, or ``None`` if there
            is none for that starting version.
        """
        return self._steps_by_from_version.get(version_from)

    def has_step(self, version_from: int) -> bool:
        """Tell whether a step starting at ``version_from`` is registered.

        Args:
            version_from: Starting version to check.

        Returns:
            bool: ``True`` if a matching step is registered.
        """
        return version_from in self._steps_by_from_version

    def latest_version(self) -> int:
        """Return the newest schema version reachable through this registry.

        Returns:
            int: ``0`` when the registry is empty, otherwise the largest
            ``version_to`` among the registered steps.
        """
        return max(
            (step.version_to for step in self._steps_by_from_version.values()),
            default=0,
        )

    def list_steps(self) -> List[MigrationStep]:
        """Return all registered steps ordered by starting version.

        Returns:
            List[MigrationStep]: Registered migration steps, ascending by
            ``version_from``.
        """
        ordered_versions = sorted(self._steps_by_from_version)
        return [self._steps_by_from_version[version] for version in ordered_versions]


# ---------------------------------------------------------------------------
# File: src/common/database/migrations/resolver.py (new file, mode 100644)
# Database schema version resolver.
# ---------------------------------------------------------------------------

from abc import ABC, abstractmethod
from typing import List, Optional

from sqlalchemy.engine import Connection

from .exceptions import DatabaseMigrationVersionError, UnrecognizedDatabaseSchemaError
from .models import DatabaseSchemaSnapshot, ResolvedSchemaVersion, SchemaVersionSource
from .schema import SQLiteSchemaInspector
from .version_store import SQLiteUserVersionStore


class BaseSchemaVersionDetector(ABC):
    """Base class for structure-based detectors of unversioned databases."""

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the name of this detector.

        Returns:
            str: Detector name; used in error messages and recorded on the
            resolved version.
        """

    @abstractmethod
    def detect_version(self, snapshot: DatabaseSchemaSnapshot) -> Optional[int]:
        """Infer a schema version from a database structure snapshot.

        Args:
            snapshot: Snapshot of the current database structure.

        Returns:
            Optional[int]: The detected version, or ``None`` if this detector
            does not recognize the structure.
        """


class SchemaVersionResolver:
    """Resolves the schema version of a database."""

    def __init__(
        self,
        version_store: Optional[SQLiteUserVersionStore] = None,
        schema_inspector: Optional[SQLiteSchemaInspector] = None,
        detectors: Optional[List[BaseSchemaVersionDetector]] = None,
    ) -> None:
        """Initialize the resolver.

        Args:
            version_store: Version store; a default ``SQLiteUserVersionStore``
                is created when omitted.
            schema_inspector: Schema inspector; a default
                ``SQLiteSchemaInspector`` is created when omitted.
            detectors: Detectors consulted for unversioned databases. The list
                is copied, so later changes to the argument are not seen.
        """
        self.version_store = version_store or SQLiteUserVersionStore()
        self.schema_inspector = schema_inspector or SQLiteSchemaInspector()
        self.detectors: List[BaseSchemaVersionDetector] = list(detectors or [])

    def add_detector(self, detector: BaseSchemaVersionDetector) -> None:
        """Append a detector to the end of the detector chain.

        Args:
            detector: Detector instance to register.
        """
        self.detectors.append(detector)

    def list_detectors(self) -> List[BaseSchemaVersionDetector]:
        """Return all registered detectors.

        Returns:
            List[BaseSchemaVersionDetector]: A copy of the detector list.
        """
        return list(self.detectors)

    def resolve(self, connection: Connection) -> ResolvedSchemaVersion:
        """Resolve the schema version of the connected database.

        Resolution order:
            1. Read ``PRAGMA user_version`` first; a positive value wins.
            2. Otherwise take a snapshot of the database structure.
            3. If the database is empty, report the empty-database version.
            4. If it is not empty, hand the snapshot to the detector chain.

        Args:
            connection: Open database connection.

        Returns:
            ResolvedSchemaVersion: The resolved version information.

        Raises:
            DatabaseMigrationVersionError: If a detector returns an invalid
                version number.
            UnrecognizedDatabaseSchemaError: If the database is not empty but
                no detector recognizes it.
        """
        recorded_version = self.version_store.read_version(connection)
        if recorded_version > 0:
            return ResolvedSchemaVersion(version=recorded_version, source=SchemaVersionSource.PRAGMA)

        snapshot = self.schema_inspector.inspect(connection)
        if snapshot.is_empty():
            return ResolvedSchemaVersion(
                version=0,
                source=SchemaVersionSource.EMPTY_DATABASE,
                snapshot=snapshot,
            )

        return self._detect_unversioned_database(snapshot)

    def _detect_unversioned_database(self, snapshot: DatabaseSchemaSnapshot) -> ResolvedSchemaVersion:
        """Identify a legacy database that carries no recorded version.

        Detectors are tried in registration order; the first one that returns
        a non-``None`` version decides the result.

        Args:
            snapshot: Snapshot of the current database structure.

        Returns:
            ResolvedSchemaVersion: Version information from the first
            detector that recognized the structure.

        Raises:
            DatabaseMigrationVersionError: If a detector returns a negative
                version number.
            UnrecognizedDatabaseSchemaError: If no detector recognizes the
                structure.
        """
        for detector in self.detectors:
            detected_version = detector.detect_version(snapshot)
            if detected_version is None:
                continue
            if detected_version < 0:
                raise DatabaseMigrationVersionError(
                    f"探测器 {detector.name!r} 返回了非法版本号: {detected_version}"
                )
            return ResolvedSchemaVersion(
                version=detected_version,
                source=SchemaVersionSource.DETECTOR,
                detector_name=detector.name,
                snapshot=snapshot,
            )

        raise UnrecognizedDatabaseSchemaError("当前数据库未记录版本号,且现有探测器无法识别其结构。")


# ---------------------------------------------------------------------------
# File: src/common/database/migrations/schema.py (new file, mode 100644)
# SQLite database structure inspection tools.
# ---------------------------------------------------------------------------

from typing import Dict, List

from sqlalchemy import text
from sqlalchemy.engine import Connection

from .models import ColumnSchema, DatabaseSchemaSnapshot, TableSchema


class SQLiteSchemaInspector:
    """Inspector for the table structure of a SQLite database."""

    def inspect(self, connection: Connection) -> DatabaseSchemaSnapshot:
        """Take a structure snapshot of every user table in the database.

        Args:
            connection: Open database connection.

        Returns:
            DatabaseSchemaSnapshot: Snapshot of the current database structure.
        """
        tables: Dict[str, TableSchema] = {
            table_name: self.get_table_schema(connection, table_name)
            for table_name in self.list_user_tables(connection)
        }
        return DatabaseSchemaSnapshot(tables=tables)

    def list_user_tables(self, connection: Connection) -> List[str]:
        """List every user table in the database.

        Tables whose name starts with the ``sqlite_`` prefix are excluded.

        Args:
            connection: Open database connection.

        Returns:
            List[str]: User table names as ordered by ``ORDER BY name``.
        """
        # ``_`` is a single-character wildcard in LIKE, so the underscore of
        # the ``sqlite_`` prefix must be escaped; otherwise a user table such
        # as ``sqlitexfoo`` would be filtered out as well.
        statement = text(
            """
            SELECT name
            FROM sqlite_master
            WHERE type = 'table'
              AND name NOT LIKE 'sqlite!_%' ESCAPE '!'
            ORDER BY name
            """
        )
        rows = connection.execute(statement).fetchall()
        return [str(row[0]) for row in rows]

    def get_table_schema(self, connection: Connection, table_name: str) -> TableSchema:
        """Read the column structure of one table via ``PRAGMA table_info``.

        Args:
            connection: Open database connection.
            table_name: Name of the table to inspect.

        Returns:
            TableSchema: Structure snapshot of the table; it has no columns
            when the PRAGMA returns no rows.
        """
        # The table name is spliced into the statement text, so quote it first.
        quoted_table_name = self._quote_identifier(table_name)
        rows = connection.exec_driver_sql(f"PRAGMA table_info({quoted_table_name})").mappings().all()

        columns: Dict[str, ColumnSchema] = {}
        for row in rows:
            column_schema = ColumnSchema(
                name=str(row["name"]),
                declared_type=str(row["type"] or ""),
                default_value=None if row["dflt_value"] is None else str(row["dflt_value"]),
                is_not_null=bool(row["notnull"]),
                primary_key_position=int(row["pk"]),
            )
            columns[column_schema.name] = column_schema

        return TableSchema(name=table_name, columns=columns)

    def table_exists(self, connection: Connection, table_name: str) -> bool:
        """Tell whether the database contains the given user table.

        The check is an exact, case-sensitive match against
        ``list_user_tables``.

        Args:
            connection: Open database connection.
            table_name: Table name to look for.

        Returns:
            bool: ``True`` if the table exists.
        """
        return table_name in self.list_user_tables(connection)

    def _quote_identifier(self, identifier: str) -> str:
        """Wrap a SQLite identifier in double quotes, doubling embedded quotes.

        Args:
            identifier: SQLite identifier to quote.

        Returns:
            str: Quoted identifier that can be spliced into a PRAGMA statement.
        """
        escaped_identifier = identifier.replace('"', '""')
        return f'"{escaped_identifier}"'


# ---------------------------------------------------------------------------
# File: src/common/database/migrations/version_store.py (new file, mode 100644)
# SQLite schema version storage.
# ---------------------------------------------------------------------------

from sqlalchemy.engine import Connection

from .exceptions import DatabaseMigrationVersionError


class SQLiteUserVersionStore:
    """SQLite schema version store backed by ``PRAGMA user_version``."""

    # ``user_version`` is a 4-byte signed integer in the database header, so
    # larger values cannot be stored faithfully.
    _MAX_USER_VERSION = 2**31 - 1

    def read_version(self, connection: Connection) -> int:
        """Read the schema version recorded in the database.

        Args:
            connection: Open database connection.

        Returns:
            int: The recorded schema version.

        Raises:
            DatabaseMigrationVersionError: If the PRAGMA returns no row, or
                the value is not an integer or is negative.
        """
        row = connection.exec_driver_sql("PRAGMA user_version").first()
        if row is None or len(row) == 0:
            raise DatabaseMigrationVersionError("读取 SQLite user_version 失败,返回结果为空。")

        version = row[0]
        if not isinstance(version, int):
            raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不是整数: {version!r}")
        if version < 0:
            raise DatabaseMigrationVersionError(f"读取到的 SQLite user_version 不能为负数: {version}")
        return version

    def write_version(self, connection: Connection, version: int) -> None:
        """Write a new schema version into the database.

        Args:
            connection: Open database connection.
            version: Schema version to record.

        Raises:
            DatabaseMigrationVersionError: If the version is not an integer or
                lies outside ``0 .. 2**31 - 1``.
        """
        self._validate_version(version)
        # The value is interpolated into the statement text, which is safe
        # only because it was validated as an in-range integer just above.
        connection.exec_driver_sql(f"PRAGMA user_version = {int(version)}")

    def _validate_version(self, version: int) -> None:
        """Check that a version number can be stored in ``user_version``.

        Args:
            version: Version number to validate.

        Raises:
            DatabaseMigrationVersionError: If the version is not an integer,
                is negative, or exceeds the signed 32-bit range.
        """
        # ``bool`` is a subclass of ``int`` but is never a meaningful version.
        if isinstance(version, bool) or not isinstance(version, int):
            raise DatabaseMigrationVersionError(f"SQLite user_version 必须是整数: {version!r}")
        if version < 0:
            raise DatabaseMigrationVersionError(f"SQLite user_version 不能小于 0: {version}")
        if version > self._MAX_USER_VERSION:
            raise DatabaseMigrationVersionError(
                f"SQLite user_version 不能大于 {self._MAX_USER_VERSION}: {version}"
            )